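"""Unit tests for onmt.translate.beam_search (BeamSearch and BeamSearchLM)."""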
import unittest
from copy import deepcopy

import torch

from onmt.translate.beam_search import BeamSearch, BeamSearchLM, GNMTGlobalScorer


class GlobalScorerStub(object):
    """Stub global scorer: no length penalty and no coverage penalty."""

    alpha = 0
    beta = 0

    def __init__(self):
        self.length_penalty = lambda x, alpha: 1.0
        self.cov_penalty = lambda cov, beta: torch.zeros(
            (1, cov.shape[-2]), device=cov.device, dtype=torch.float
        )
        self.has_cov_pen = False
        self.has_len_pen = False

    def update_global_state(self, beam):
        pass

    def score(self, beam, scores):
        return scores


class TestBeamSearch(unittest.TestCase):
    BLOCKED_SCORE = -10e20

    def test_advance_with_all_repeats_gets_blocked(self):
        # every beam keeps predicting the same token, so once the blocked
        # ngram length is reached the surviving beam gets the blocked score
        beam_sz = 5
        n_words = 100
        repeat_idx = 47
        ngram_repeat = 3
        device_init = torch.zeros(1, 1)
        for batch_sz in [1, 3]:
            beam = BeamSearch(
                beam_sz,
                batch_sz,
                0,
                1,
                2,
                3,
                1,
                2,
                GlobalScorerStub(),
                0,
                30,
                False,
                ngram_repeat,
                set(),
                False,
                0.0,
                False,
            )
            beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
            for i in range(ngram_repeat + 4):
                # predict repeat_idx over and over again
                word_probs = torch.full((batch_sz * beam_sz, n_words), -float("inf"))
                word_probs[0::beam_sz, repeat_idx] = 0

                attns = torch.randn(batch_sz * beam_sz, 1, 53)
                beam.advance(word_probs, attns)

                if i < ngram_repeat:
                    # before the ngram has repeated, only beam 0 has a
                    # finite score
                    expected_scores = torch.tensor(
                        [0] + [-float("inf")] * (beam_sz - 1)
                    ).repeat(batch_sz, 1)
                    self.assertTrue(beam.topk_log_probs.equal(expected_scores))
                elif i % ngram_repeat == 0:
                    # once the ngram repeats, beam 0 is pushed down to
                    # BLOCKED_SCORE; it is still the best beam, since the
                    # others stay at -inf
                    expected_scores = torch.tensor(
                        [self.BLOCKED_SCORE] + [-float("inf")] * (beam_sz - 1)
                    ).repeat(batch_sz, 1)
                    self.assertTrue(beam.topk_log_probs.equal(expected_scores))
                else:
                    # the repetition keeps being blocked, so the expected
                    # scores stay the same on the following steps
                    expected_scores = torch.tensor(
                        [self.BLOCKED_SCORE] + [-float("inf")] * (beam_sz - 1)
                    ).repeat(batch_sz, 1)
                    self.assertTrue(beam.topk_log_probs.equal(expected_scores))

    def test_advance_with_some_repeats_gets_blocked(self):
        # beam 0 repeats an ngram and gets blocked; beam 1 keeps predicting
        # fresh tokens and survives
        beam_sz = 5
        n_words = 100
        repeat_idx = 47
        ngram_repeat = 3
        no_repeat_score = -2.3
        repeat_score = -0.1
        device_init = torch.zeros(1, 1)
        for batch_sz in [1, 3]:
            beam = BeamSearch(
                beam_sz,
                batch_sz,
                0,
                1,
                2,
                3,
                1,
                2,
                GlobalScorerStub(),
                0,
                30,
                False,
                ngram_repeat,
                set(),
                False,
                0.0,
                False,
            )
            beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
            for i in range(ngram_repeat + 4):
                # non-interesting beams get dummy (-inf) scores
                word_probs = torch.full((batch_sz * beam_sz, n_words), -float("inf"))
                if i == 0:
                    # on the initial step only beam 0's scores matter: give it
                    # two candidates, the better of which will be repeated
                    word_probs[0::beam_sz, repeat_idx] = repeat_score
                    word_probs[0::beam_sz, repeat_idx + i + 1] = no_repeat_score
                else:
                    # predict the same token in beam 0 every step
                    word_probs[0::beam_sz, repeat_idx] = 0
                    # beam 1 keeps predicting new tokens
                    word_probs[1::beam_sz, repeat_idx + i + 1] = 0
                attns = torch.randn(batch_sz * beam_sz, 1, 53)
                beam.advance(word_probs, attns)
                if i < ngram_repeat:
                    self.assertFalse(
                        beam.topk_log_probs[0::beam_sz].eq(self.BLOCKED_SCORE).any()
                    )
                    self.assertFalse(
                        beam.topk_log_probs[1::beam_sz].eq(self.BLOCKED_SCORE).any()
                    )
                elif i == ngram_repeat:
                    # the repeating hypothesis is blocked and the
                    # non-repeating one moves up to beam slot 0
                    self.assertFalse(
                        beam.topk_log_probs[:, 0].eq(self.BLOCKED_SCORE).any()
                    )

                    expected = torch.full([batch_sz, beam_sz], float("-inf"))
                    expected[:, 0] = no_repeat_score
                    expected[:, 1] = self.BLOCKED_SCORE
                    self.assertTrue(beam.topk_log_probs[:, :].equal(expected))
                else:
                    # the non-repeating hypothesis stays on top; more blocked
                    # copies of the repeating one pile up behind it
                    self.assertFalse(
                        beam.topk_log_probs[:, 0].eq(self.BLOCKED_SCORE).any()
                    )

                    expected = torch.full([batch_sz, beam_sz], float("-inf"))
                    expected[:, 0] = no_repeat_score
                    expected[:, 1:3] = self.BLOCKED_SCORE
                    expected[:, 3:] = float("-inf")
                    self.assertTrue(beam.topk_log_probs.equal(expected))

    def test_repeating_excluded_index_does_not_die(self):
        # beam 0 repeats a blocked ngram, beam 1 predicts fresh tokens, and
        # beam 2 repeats a token from the exclusion set, which must not be
        # blocked
        beam_sz = 5
        n_words = 100
        repeat_idx = 47  # repeated and should be blocked
        repeat_idx_ignored = 7  # repeated but excluded, should not be blocked
        ngram_repeat = 3
        device_init = torch.zeros(1, 1)
        for batch_sz in [1, 3]:
            beam = BeamSearch(
                beam_sz,
                batch_sz,
                0,
                1,
                2,
                3,
                1,
                2,
                GlobalScorerStub(),
                0,
                30,
                False,
                ngram_repeat,
                {repeat_idx_ignored},
                False,
                0.0,
                False,
            )
            beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
            for i in range(ngram_repeat + 4):
                # non-interesting beams get dummy (-inf) scores
                word_probs = torch.full((batch_sz * beam_sz, n_words), -float("inf"))
                if i == 0:
                    word_probs[0::beam_sz, repeat_idx] = -0.1
                    word_probs[0::beam_sz, repeat_idx + i + 1] = -2.3
                    word_probs[0::beam_sz, repeat_idx_ignored] = -5.0
                else:
                    # predict the same token in beam 0 every step
                    word_probs[0::beam_sz, repeat_idx] = 0
                    # beam 1 keeps predicting new tokens
                    word_probs[1::beam_sz, repeat_idx + i + 1] = 0
                    # beam 2 repeats the excluded token
                    word_probs[2::beam_sz, repeat_idx_ignored] = 0
                attns = torch.randn(batch_sz * beam_sz, 1, 53)
                beam.advance(word_probs, attns)
                if i < ngram_repeat:
                    self.assertFalse(
                        beam.topk_log_probs[:, 0].eq(self.BLOCKED_SCORE).any()
                    )
                    self.assertFalse(
                        beam.topk_log_probs[:, 1].eq(self.BLOCKED_SCORE).any()
                    )
                    self.assertFalse(
                        beam.topk_log_probs[:, 2].eq(self.BLOCKED_SCORE).any()
                    )
                else:
                    # the non-repeating beam and the beam repeating the
                    # excluded token both survive; only the beam repeating
                    # the blocked ngram gets the blocked score
                    self.assertFalse(
                        beam.topk_log_probs[:, 0].eq(self.BLOCKED_SCORE).any()
                    )
                    self.assertTrue(beam.topk_log_probs[:, 0].eq(-2.3).all())
                    self.assertFalse(
                        beam.topk_log_probs[:, 1].eq(self.BLOCKED_SCORE).all()
                    )
                    self.assertTrue(beam.topk_log_probs[:, 1].eq(-5.0).all())

                    self.assertTrue(
                        beam.topk_log_probs[:, 2].eq(self.BLOCKED_SCORE).all()
                    )

    def test_doesnt_predict_eos_if_shorter_than_min_len(self):
        # EOS is the top-scoring prediction from the start, but it must not
        # be chosen while the hypotheses are shorter than min_length
        for batch_sz in [1, 3]:
            beam_sz = 5
            n_words = 100
            _non_eos_idxs = [47, 51, 13, 88, 99]
            valid_score_dist = torch.log_softmax(
                torch.tensor([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]), dim=0
            )
            min_length = 5
            eos_idx = 2
            lengths = torch.randint(0, 30, (batch_sz,))
            beam = BeamSearch(
                beam_sz,
                batch_sz,
                0,
                1,
                2,
                3,
                1,
                2,
                GlobalScorerStub(),
                min_length,
                30,
                False,
                0,
                set(),
                False,
                0.0,
                False,
            )
            device_init = torch.zeros(1, 1)
            beam.initialize(device_init, lengths)
            all_attns = []
            for i in range(min_length + 4):
                # non-interesting beams get dummy (-inf) scores
                word_probs = torch.full((batch_sz * beam_sz, n_words), -float("inf"))
                if i == 0:
                    # EOS is the best-scoring option for beam 0
                    word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                    # also include at least beam_sz candidates other than EOS
                    for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                        word_probs[0::beam_sz, j] = score
                else:
                    # EOS stays the best-scoring option for beam 0
                    word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                    # provide non-EOS candidates for the other beams
                    for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])
                    ):
                        beam_idx = min(beam_sz - 1, k)
                        word_probs[beam_idx::beam_sz, j] = score

                attns = torch.randn(batch_sz * beam_sz, 1, 53)
                all_attns.append(attns)
                beam.advance(word_probs, attns)
                if i < min_length:
                    expected_score_dist = (i + 1) * valid_score_dist[1:].unsqueeze(0)
                    self.assertTrue(beam.topk_log_probs.allclose(expected_score_dist))
                elif i == min_length:
                    # now the top beam is allowed to predict EOS and finishes
                    self.assertTrue(beam.is_finished[:, 0].eq(1).all())
                    self.assertTrue(beam.is_finished[:, 1:].eq(0).all())
                else:
                    # behavior after the top beam finishes is not of interest
                    # here; just keep advancing
                    pass

    def test_beam_is_done_when_X_beams_eos_using_min_length(self):
        # beam 1 starts predicting EOS once min_length is reached; a few
        # steps later every beam predicts EOS and the search is done
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(
            torch.tensor([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]), dim=0
        )
        min_length = 5
        eos_idx = 2
        beam = BeamSearch(
            beam_sz,
            batch_sz,
            0,
            1,
            2,
            3,
            1,
            2,
            GlobalScorerStub(),
            min_length,
            30,
            False,
            0,
            set(),
            False,
            0.0,
            False,
        )
        device_init = torch.zeros(1, 1)
        beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
        for i in range(min_length + 4):
            # non-interesting beams get dummy (-inf) scores
            word_probs = torch.full((batch_sz * beam_sz, n_words), -float("inf"))
            if i == 0:
                # EOS is the best-scoring option for beam 0
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # also include at least beam_sz candidates other than EOS
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # beam 1 predicts EOS; the other beams get non-EOS candidates
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                for k, (j, score) in enumerate(
                    zip(_non_eos_idxs, valid_score_dist[1:])
                ):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                # every beam predicts EOS
                for j in range(beam_sz):
                    word_probs[j::beam_sz, eos_idx] = valid_score_dist[0]

            attns = torch.randn(batch_sz * beam_sz, 1, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
            elif i == min_length:
                # beam 1 finishes, but the top beam is still alive
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
            else:
                # the top beam finishes, so the search is done
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)

    def test_beam_returns_attn_with_correct_length(self):
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(
            torch.tensor([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]), dim=0
        )
        min_length = 5
        eos_idx = 2
        inp_lens = torch.randint(1, 30, (batch_sz,))
        beam = BeamSearch(
            beam_sz,
            batch_sz,
            0,
            1,
            2,
            3,
            1,
            2,
            GlobalScorerStub(),
            min_length,
            30,
            True,
            0,
            set(),
            False,
            0.0,
            False,
        )
        device_init = torch.zeros(1, 1)
        _, _, inp_lens, _ = beam.initialize(device_init, inp_lens)

        for i in range(min_length + 2):
            # non-interesting beams get dummy (-inf) scores
            word_probs = torch.full((batch_sz * beam_sz, n_words), -float("inf"))
            if i == 0:
                # EOS is the best-scoring option for beam 0
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # also include at least beam_sz candidates other than EOS
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # beam 1 predicts EOS; the other beams get non-EOS candidates
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                for k, (j, score) in enumerate(
                    zip(_non_eos_idxs, valid_score_dist[1:])
                ):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                # every beam predicts EOS
                for j in range(beam_sz):
                    word_probs[j::beam_sz, eos_idx] = valid_score_dist[0]

            attns = torch.randn(batch_sz * beam_sz, 1, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
                # nothing has finished, so no attention has been collected
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            elif i == min_length:
                # beam 1 finishes, but the examples are not done yet, so the
                # collected attention is still empty
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            else:
                # the top beam finishes and the search is done
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
                # each example returns two finished hypotheses
                for b in range(batch_sz):
                    self.assertEqual(len(beam.attention[b]), 2)
                    for k in range(2):
                        # attention is truncated to the source length
                        self.assertEqual(beam.attention[b][k].shape[-1], inp_lens[b])
                    # the two hypotheses finished one step apart
                    self.assertEqual(beam.attention[b][0].shape[0], i + 1)

                    self.assertEqual(beam.attention[b][1].shape[0], i)

                break


class TestBeamSearchAgainstReferenceCase(unittest.TestCase):
    # walks a small beam search step by step and checks every intermediate
    # result (scores, predictions, backpointers, finished flags) against
    # values computed independently in the test
    BEAM_SZ = 5
    EOS_IDX = 2
    N_WORDS = 8
    N_BEST = 3
    DEAD_SCORE = -1e20
    BATCH_SZ = 3
    INP_SEQ_LEN = 53

    def random_attn(self):
        return torch.randn(self.BATCH_SZ * self.BEAM_SZ, 1, self.INP_SEQ_LEN)

    def init_step(self, beam, expected_len_pen):
        # the top beam_sz predictions are all non-EOS, so nothing finishes
        init_scores = torch.log_softmax(
            torch.tensor([[0, 0, 0, 4, 5, 3, 2, 1]], dtype=torch.float), dim=1
        )
        init_scores = deepcopy(init_scores.repeat(self.BATCH_SZ * self.BEAM_SZ, 1))
        new_scores = init_scores + beam.topk_log_probs.view(-1).unsqueeze(1)
        expected_beam_scores, expected_preds_0 = new_scores.view(
            self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS
        ).topk(self.BEAM_SZ, dim=-1)
        beam.advance(deepcopy(init_scores), self.random_attn())
        self.assertTrue(beam.topk_log_probs.allclose(expected_beam_scores))
        self.assertTrue(beam.topk_ids.equal(expected_preds_0))
        self.assertFalse(beam.is_finished.any())
        self.assertFalse(beam.done)
        return expected_beam_scores

    def first_step(self, beam, expected_beam_scores, expected_len_pen):
        # no hypothesis has finished yet
        assert beam.is_finished.sum() == 0
        scores_1 = torch.log_softmax(
            torch.tensor(
                [
                    [0, 0, 0, 0.3, 0, 0.51, 0.2, 0],
                    [0, 0, 1.5, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0.49, 0.48, 0, 0],
                    [0, 0, 0, 0.2, 0.2, 0.2, 0.2, 0.2],
                    [0, 0, 0, 0.2, 0.2, 0.2, 0.2, 0.2],
                ]
            ),
            dim=1,
        )
        scores_1 = scores_1.repeat(self.BATCH_SZ, 1)

        beam.advance(deepcopy(scores_1), self.random_attn())

        new_scores = scores_1 + expected_beam_scores.view(-1).unsqueeze(1)
        expected_beam_scores, unreduced_preds = new_scores.view(
            self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS
        ).topk(self.BEAM_SZ, -1)
        expected_bptr_1 = torch.div(
            unreduced_preds, self.N_WORDS, rounding_mode="trunc"
        )

        expected_preds_1 = unreduced_preds - expected_bptr_1 * self.N_WORDS
        self.assertTrue(beam.topk_log_probs.allclose(expected_beam_scores))
        self.assertTrue(
            beam.topk_scores.allclose(expected_beam_scores / expected_len_pen)
        )
        self.assertTrue(beam.topk_ids.equal(expected_preds_1))
        self.assertTrue(beam.current_backptr.equal(expected_bptr_1))
        # one hypothesis per example predicts EOS, landing in beam slot 2
        self.assertEqual(beam.is_finished.sum(), self.BATCH_SZ)
        self.assertTrue(beam.is_finished[:, 2].all())
        beam.update_finished()
        self.assertFalse(beam.top_beam_finished.any())
        self.assertFalse(beam.done)
        return expected_beam_scores

    def second_step(self, beam, expected_beam_scores, expected_len_pen):
        scores_2 = torch.log_softmax(
            torch.tensor(
                [
                    [0, 0, 0, 0.3, 0, 0.51, 0.2, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 5000, 0.48, 0, 0],
                    [0, 0, 50, 0.2, 0.2, 0.2, 0.2, 0.2],
                    [0, 0, 0, 0.2, 0.2, 0.2, 0.2, 0.2],
                ]
            ),
            dim=1,
        )
        scores_2 = scores_2.repeat(self.BATCH_SZ, 1)

        beam.advance(deepcopy(scores_2), self.random_attn())

        # the hypothesis that finished in beam slot 2 is dead; its score must
        # not propagate into the next step
        expected_beam_scores[:, 2 :: self.BEAM_SZ] = self.DEAD_SCORE
        new_scores = scores_2 + expected_beam_scores.view(-1).unsqueeze(1)
        expected_beam_scores, unreduced_preds = new_scores.view(
            self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS
        ).topk(self.BEAM_SZ, -1)
        expected_bptr_2 = torch.div(
            unreduced_preds, self.N_WORDS, rounding_mode="trunc"
        )

        expected_preds_2 = unreduced_preds - expected_bptr_2 * self.N_WORDS

        self.assertTrue(beam.topk_log_probs.allclose(expected_beam_scores))
        self.assertTrue(
            beam.topk_scores.allclose(expected_beam_scores / expected_len_pen)
        )
        self.assertTrue(beam.topk_ids.equal(expected_preds_2))
        self.assertTrue(beam.current_backptr.equal(expected_bptr_2))
        # again one hypothesis per example predicts EOS, this time landing
        # in beam slot 0
        self.assertEqual(beam.is_finished.sum(), self.BATCH_SZ)
        self.assertTrue(beam.is_finished[:, 0].all())
        # the finished hypothesis descends from beam 3
        self.assertTrue(expected_bptr_2[:, 0].eq(3).all())
        beam.update_finished()
        self.assertTrue(beam.top_beam_finished.all())
        self.assertFalse(beam.done)
        return expected_beam_scores

    def third_step(self, beam, expected_beam_scores, expected_len_pen):
        # several beams now put very high mass on EOS
        scores_3 = torch.log_softmax(
            torch.tensor(
                [
                    [0, 0, 10000, 0, 5000, 0.51, 0.2, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 10000, 0, 0, 5000, 0, 0],
                    [0, 0, 50, 0.2, 0.2, 0.2, 0.2, 0.2],
                    [0, 0, 50, 0, 0.2, 0.2, 0.2, 0.2],
                ]
            ),
            dim=1,
        )
        scores_3 = scores_3.repeat(self.BATCH_SZ, 1)

        beam.advance(deepcopy(scores_3), self.random_attn())

        # the hypothesis that finished in beam slot 0 is dead
        expected_beam_scores[:, 0 :: self.BEAM_SZ] = self.DEAD_SCORE
        new_scores = scores_3 + expected_beam_scores.view(-1).unsqueeze(1)
        expected_beam_scores, unreduced_preds = new_scores.view(
            self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS
        ).topk(self.BEAM_SZ, -1)
        expected_bptr_3 = torch.div(
            unreduced_preds, self.N_WORDS, rounding_mode="trunc"
        )

        expected_preds_3 = unreduced_preds - expected_bptr_3 * self.N_WORDS
        self.assertTrue(beam.topk_log_probs.allclose(expected_beam_scores))
        self.assertTrue(
            beam.topk_scores.allclose(expected_beam_scores / expected_len_pen)
        )
        self.assertTrue(beam.topk_ids.equal(expected_preds_3))
        self.assertTrue(beam.current_backptr.equal(expected_bptr_3))
        # three hypotheses per example finish on this step
        self.assertEqual(beam.is_finished.sum(), self.BATCH_SZ * 3)
        # the hypothesis in slot 1 descends from beam 3
        self.assertTrue(expected_bptr_3[:, 1].eq(3).all())
        beam.update_finished()
        self.assertTrue(beam.top_beam_finished.all())
        # n_best hypotheses have now finished for every example
        self.assertTrue(beam.done)
        return expected_beam_scores

    def test_beam_advance_against_known_reference(self):
        beam = BeamSearch(
            self.BEAM_SZ,
            self.BATCH_SZ,
            0,
            1,
            2,
            3,
            1,
            self.N_BEST,
            GlobalScorerStub(),
            0,
            30,
            False,
            0,
            set(),
            False,
            0.0,
            False,
        )
        device_init = torch.zeros(1, 1)
        beam.initialize(device_init, torch.randint(0, 30, (self.BATCH_SZ,)))
        expected_beam_scores = self.init_step(beam, 1)
        expected_beam_scores = self.first_step(beam, expected_beam_scores, 1)
        expected_beam_scores = self.second_step(beam, expected_beam_scores, 1)
        self.third_step(beam, expected_beam_scores, 1)


class TestBeamWithLengthPenalty(TestBeamSearchAgainstReferenceCase):
    # same reference trace as above, but with the GNMT "avg" length penalty,
    # so topk_scores are divided by a penalty that grows with the hypothesis
    # length (the expected_len_pen arguments passed below)

    def test_beam_advance_against_known_reference(self):
        scorer = GNMTGlobalScorer(1.0, 0.0, "avg", "none")
        beam = BeamSearch(
            self.BEAM_SZ,
            self.BATCH_SZ,
            0,
            1,
            2,
            3,
            1,
            self.N_BEST,
            scorer,
            0,
            30,
            False,
            0,
            set(),
            False,
            0.0,
            False,
        )
        device_init = torch.zeros(1, 1)
        beam.initialize(device_init, torch.randint(0, 30, (self.BATCH_SZ,)))
        expected_beam_scores = self.init_step(beam, 1.0)
        expected_beam_scores = self.first_step(beam, expected_beam_scores, 3)
        expected_beam_scores = self.second_step(beam, expected_beam_scores, 4)
        self.third_step(beam, expected_beam_scores, 5)


class TestBeamSearchLM(TestBeamSearchAgainstReferenceCase):
    def finish_first_beam_step(self, beam):
        # force every beam of the first batch example to predict EOS so that
        # the example finishes and is removed from the batch
        scores_finish = torch.log_softmax(
            torch.tensor(
                [
                    [0, 0, 10000, 0, 5000, 0.51, 0.2, 0],
                    [100000, 100001, 0, 0, 0, 0, 0, 0],
                    [0, 100000, 0, 0, 0, 5000, 0, 0],
                    [0, 0, 0, 0.2, 0.2, 0.2, 0.2, 0.2],
                    [0, 0, 0, 0, 0.2, 0.2, 0.2, 0.2],
                ]
            ),
            dim=1,
        )
        scores_finish = scores_finish.repeat(self.BATCH_SZ, 1)
        scores_finish[: self.BEAM_SZ, beam.eos] = 100
        beam.advance(scores_finish, None)

        any_finished = beam.is_finished.any()
        if any_finished:
            beam.update_finished()

    def test_beam_lm_increase_src_len(self):
        # in the LM search, src_len grows by one for every decoding step
        beam = BeamSearchLM(
            self.BEAM_SZ,
            self.BATCH_SZ,
            0,
            1,
            2,
            3,
            1,
            self.N_BEST,
            GlobalScorerStub(),
            0,
            30,
            False,
            0,
            set(),
            False,
            0.0,
            False,
        )
        device_init = torch.zeros(1, 1)
        src_len = torch.randint(0, 30, (self.BATCH_SZ,))
        fn_map_state, _, _, _ = beam.initialize(device_init, src_len)
        expected_beam_scores = self.init_step(beam, 1)
        expected_beam_scores = self.first_step(beam, expected_beam_scores, 1)
        expected_beam_scores = self.second_step(beam, expected_beam_scores, 1)
        self.third_step(beam, expected_beam_scores, 1)

        n_steps = beam.alive_seq.shape[-1] - 1
        self.assertTrue(beam.src_len.equal(n_steps + fn_map_state(src_len, dim=0)))

    def test_beam_lm_update_src_len_when_finished(self):
        # when a batch example finishes, its entries are dropped from src_len
        beam = BeamSearchLM(
            self.BEAM_SZ,
            self.BATCH_SZ,
            0,
            1,
            2,
            3,
            1,
            self.N_BEST,
            GlobalScorerStub(),
            0,
            30,
            False,
            0,
            set(),
            False,
            0.0,
            False,
        )
        device_init = torch.zeros(1, 1)
        src_len = torch.randint(0, 30, (self.BATCH_SZ,))
        fn_map_state, _, _, _ = beam.initialize(device_init, src_len)
        self.init_step(beam, 1)
        self.finish_first_beam_step(beam)

        n_steps = beam.alive_seq.shape[-1] - 1
        self.assertTrue(beam.src_len.equal(n_steps + fn_map_state(src_len[1:], dim=0)))
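

# Convenience entry point so the module can also be run directly;
# test runners such as pytest discover the TestCase classes on their own.
if __name__ == "__main__":
    unittest.main()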