|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
|
|
from transformers import is_torch_available |
|
from transformers.testing_utils import require_torch, torch_device |
|
|
|
from ..test_modeling_common import floats_tensor, ids_tensor |
|
|
|
|
|
if is_torch_available(): |
|
import torch |
|
|
|
from transformers.generation import ( |
|
BeamHypotheses, |
|
BeamSearchScorer, |
|
ConstrainedBeamSearchScorer, |
|
DisjunctiveConstraint, |
|
PhrasalConstraint, |
|
) |
|
|
|
|
|
class BeamSearchTester: |
|
def __init__( |
|
self, |
|
parent, |
|
batch_size=3, |
|
sequence_length=10, |
|
vocab_size=99, |
|
pad_token_id=0, |
|
max_length=20, |
|
num_beams=4, |
|
length_penalty=2.0, |
|
do_early_stopping=True, |
|
num_beam_hyps_to_keep=2, |
|
): |
|
self.parent = parent |
|
self.batch_size = batch_size |
|
self.sequence_length = sequence_length |
|
self.vocab_size = vocab_size |
|
self.pad_token_id = pad_token_id |
|
self.max_length = max_length |
|
self.num_beams = num_beams |
|
self.length_penalty = length_penalty |
|
self.do_early_stopping = do_early_stopping |
|
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep |
|
|
|
|
|
self.eos_token_id = vocab_size + 1 |
|
|
|
def prepare_beam_scorer(self, **kwargs): |
|
return BeamSearchScorer( |
|
batch_size=kwargs.get("batch_size", self.batch_size), |
|
num_beams=kwargs.get("num_beams", self.num_beams), |
|
device=torch_device, |
|
length_penalty=kwargs.get("length_penalty", self.length_penalty), |
|
do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping), |
|
num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep), |
|
) |
|
|
|
def prepare_inputs(self): |
|
input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size) |
|
next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device) |
|
next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device) |
|
next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True) |
|
return (input_ids, next_tokens, next_indices, next_scores) |
|
|
|
def check_beam_hypotheses(self, input_ids, *args): |
|
|
|
beam_scorer = self.prepare_beam_scorer(do_early_stopping=True) |
|
beam_hyp = beam_scorer._beam_hyps[0] |
|
|
|
self.parent.assertEqual(len(beam_scorer._beam_hyps), self.batch_size) |
|
|
|
|
|
self.parent.assertTrue(isinstance(beam_hyp, BeamHypotheses)) |
|
|
|
|
|
self.parent.assertEqual(beam_hyp.num_beams, self.num_beams) |
|
|
|
|
|
for beam_idx in range(self.num_beams): |
|
beam_hyp.add(input_ids[beam_idx], -10.0) |
|
|
|
|
|
self.parent.assertTrue(beam_hyp.is_done(-10.0, 5)) |
|
|
|
|
|
beam_scorer = self.prepare_beam_scorer(do_early_stopping=False) |
|
beam_hyp = beam_scorer._beam_hyps[0] |
|
|
|
|
|
for beam_idx in range(self.num_beams + 1): |
|
beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx)) |
|
|
|
|
|
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length**beam_hyp.length_penalty)) |
|
|
|
|
|
self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length)) |
|
|
|
|
|
self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length)) |
|
|
|
def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_scores): |
|
|
|
beam_scorer = self.prepare_beam_scorer() |
|
|
|
tokens = next_tokens.clone() |
|
tokens[0, :] = self.eos_token_id |
|
|
|
with self.parent.assertRaises(ValueError): |
|
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id) |
|
|
|
|
|
beam_scorer = self.prepare_beam_scorer() |
|
|
|
tokens = next_tokens.clone() |
|
tokens[:, : self.num_beams] = self.eos_token_id |
|
beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device) |
|
beam_indices = tuple(tuple(b) for b in beam_indices) |
|
beam_scorer.process( |
|
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices |
|
) |
|
|
|
self.parent.assertTrue(beam_scorer.is_done) |
|
|
|
|
|
beam_scorer = self.prepare_beam_scorer() |
|
|
|
tokens = next_tokens.clone() |
|
tokens[:, 1] = self.eos_token_id |
|
beam_outputs = beam_scorer.process( |
|
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices |
|
) |
|
output_scores = beam_outputs["next_beam_scores"] |
|
output_tokens = beam_outputs["next_beam_tokens"] |
|
output_indices = beam_outputs["next_beam_indices"] |
|
|
|
def cut_expected_tensor(tensor): |
|
return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten() |
|
|
|
|
|
|
|
expected_output_tokens = cut_expected_tensor(tokens) |
|
expected_output_scores = cut_expected_tensor(next_scores) |
|
|
|
|
|
offset = torch.div( |
|
torch.arange(self.num_beams * self.batch_size, device=torch_device), self.num_beams, rounding_mode="floor" |
|
) |
|
expected_output_indices = cut_expected_tensor(next_indices) + offset * self.num_beams |
|
|
|
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist()) |
|
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist()) |
|
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3)) |
|
|
|
|
|
expected_beam_indices = list(range(10)) |
|
for batch_idx in range(self.batch_size): |
|
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1] |
|
self.parent.assertListEqual( |
|
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist() |
|
) |
|
self.parent.assertListEqual( |
|
expected_beam_indices + [correct_idx], |
|
torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(), |
|
) |
|
|
|
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores): |
|
|
|
max_length = self.sequence_length + 1 |
|
beam_scorer = self.prepare_beam_scorer(num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False) |
|
|
|
|
|
tokens = next_tokens.clone() |
|
|
|
tokens[0, 0] = self.eos_token_id |
|
|
|
next_scores[0, 0] = 0.0 |
|
beam_outputs = beam_scorer.process( |
|
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id |
|
) |
|
output_scores = beam_outputs["next_beam_scores"] |
|
output_tokens = beam_outputs["next_beam_tokens"] |
|
output_indices = beam_outputs["next_beam_indices"] |
|
|
|
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1) |
|
|
|
|
|
beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device) |
|
beam_indices = tuple(tuple(b) for b in beam_indices) |
|
sequence_output = beam_scorer.finalize( |
|
input_ids, |
|
output_scores, |
|
output_tokens, |
|
output_indices, |
|
pad_token_id=self.pad_token_id, |
|
eos_token_id=self.eos_token_id, |
|
max_length=max_length, |
|
beam_indices=beam_indices, |
|
) |
|
|
|
sequences = sequence_output["sequences"] |
|
sequence_scores = sequence_output["sequence_scores"] |
|
|
|
|
|
self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length]) |
|
self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size]) |
|
|
|
|
|
self.parent.assertFalse((sequence_scores > 0).any().item()) |
|
|
|
|
|
self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id) |
|
|
|
|
|
self.parent.assertNotEqual(sequences[1, -1].item(), self.eos_token_id) |
|
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id) |
|
|
|
|
|
beam_scorer.num_beam_hyps_to_keep = self.num_beams |
|
sequence_output = beam_scorer.finalize( |
|
input_ids, |
|
output_scores, |
|
output_tokens, |
|
output_indices, |
|
pad_token_id=self.pad_token_id, |
|
eos_token_id=self.eos_token_id, |
|
max_length=max_length, |
|
beam_indices=beam_indices, |
|
) |
|
sequences = sequence_output["sequences"] |
|
sequence_scores = sequence_output["sequence_scores"] |
|
|
|
self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length]) |
|
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size]) |
|
|
|
|
|
class ConstrainedBeamSearchTester: |
|
def __init__( |
|
self, |
|
parent, |
|
constraints=None, |
|
batch_size=3, |
|
sequence_length=10, |
|
vocab_size=99, |
|
pad_token_id=0, |
|
max_length=20, |
|
num_beams=4, |
|
length_penalty=2.0, |
|
do_early_stopping=True, |
|
num_beam_hyps_to_keep=2, |
|
): |
|
self.parent = parent |
|
self.batch_size = batch_size |
|
self.sequence_length = sequence_length |
|
self.vocab_size = vocab_size |
|
self.pad_token_id = pad_token_id |
|
self.max_length = max_length |
|
self.num_beams = num_beams |
|
self.length_penalty = length_penalty |
|
self.do_early_stopping = do_early_stopping |
|
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep |
|
|
|
if constraints is None: |
|
force_tokens = torch.randint(10, 50, (1, 2))[0].tolist() |
|
disjunctive_tokens = torch.randint(10, 50, (2, 2)).tolist() |
|
|
|
constraints = [PhrasalConstraint(force_tokens), DisjunctiveConstraint(disjunctive_tokens)] |
|
self.constraints = constraints |
|
|
|
self.eos_token_id = vocab_size + 1 |
|
|
|
def prepare_constrained_beam_scorer(self, **kwargs): |
|
return ConstrainedBeamSearchScorer( |
|
constraints=kwargs.get("constraints", self.constraints), |
|
batch_size=kwargs.get("batch_size", self.batch_size), |
|
num_beams=kwargs.get("num_beams", self.num_beams), |
|
device=torch_device, |
|
length_penalty=kwargs.get("length_penalty", self.length_penalty), |
|
do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping), |
|
num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep), |
|
) |
|
|
|
def prepare_inputs(self): |
|
input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size) |
|
next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device) |
|
next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device) |
|
next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True) |
|
scores_for_all_vocab, _ = ( |
|
-floats_tensor((self.batch_size * self.num_beams, self.vocab_size)).to(torch_device) |
|
).sort(descending=True) |
|
return (input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab) |
|
|
|
def check_beam_hypotheses(self, input_ids, *args): |
|
|
|
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=True) |
|
beam_hyp = constrained_beam_scorer._beam_hyps[0] |
|
|
|
self.parent.assertEqual(len(constrained_beam_scorer._beam_hyps), self.batch_size) |
|
|
|
|
|
self.parent.assertTrue(isinstance(beam_hyp, BeamHypotheses)) |
|
|
|
|
|
self.parent.assertEqual(beam_hyp.num_beams, self.num_beams) |
|
|
|
|
|
for beam_idx in range(self.num_beams): |
|
beam_hyp.add(input_ids[beam_idx], -10.0) |
|
|
|
|
|
self.parent.assertTrue(beam_hyp.is_done(-10.0, 5)) |
|
|
|
|
|
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=False) |
|
beam_hyp = constrained_beam_scorer._beam_hyps[0] |
|
|
|
|
|
for beam_idx in range(self.num_beams + 1): |
|
beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx)) |
|
|
|
|
|
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length**beam_hyp.length_penalty)) |
|
|
|
|
|
self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length)) |
|
|
|
|
|
self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length)) |
|
|
|
def check_constrained_beam_scorer_update( |
|
self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab |
|
): |
|
|
|
constrained_beam_scorer = self.prepare_constrained_beam_scorer() |
|
stacked_token_ids = [] |
|
for constraint in self.constraints: |
|
token_ids = constraint.token_ids |
|
token_ids = token_ids[0] if isinstance(token_ids[0], list) else token_ids |
|
stacked_token_ids = stacked_token_ids + token_ids |
|
|
|
fulfilling_sequence = torch.LongTensor(stacked_token_ids) |
|
fulfill_len = fulfilling_sequence.size(0) |
|
input_ids[:, :fulfill_len] = fulfilling_sequence |
|
|
|
tokens = next_tokens.clone() |
|
tokens[0, :] = self.eos_token_id |
|
|
|
with self.parent.assertRaises(ValueError): |
|
constrained_beam_scorer.process( |
|
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id |
|
) |
|
|
|
|
|
constrained_beam_scorer = self.prepare_constrained_beam_scorer() |
|
|
|
tokens = next_tokens.clone() |
|
tokens[:, : self.num_beams] = self.eos_token_id |
|
constrained_beam_scorer.process( |
|
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id |
|
) |
|
|
|
self.parent.assertTrue(constrained_beam_scorer.is_done) |
|
|
|
|
|
constrained_beam_scorer = self.prepare_constrained_beam_scorer() |
|
|
|
tokens = next_tokens.clone() |
|
tokens[:, 1] = self.eos_token_id |
|
beam_outputs = constrained_beam_scorer.process( |
|
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id |
|
) |
|
output_scores = beam_outputs["next_beam_scores"] |
|
output_tokens = beam_outputs["next_beam_tokens"] |
|
output_indices = beam_outputs["next_beam_indices"] |
|
|
|
def cut_expected_tensor(tensor): |
|
return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten() |
|
|
|
|
|
|
|
expected_output_tokens = cut_expected_tensor(tokens) |
|
expected_output_scores = cut_expected_tensor(next_scores) |
|
|
|
|
|
offset = torch.div( |
|
torch.arange(self.num_beams * self.batch_size, device=torch_device), self.num_beams, rounding_mode="floor" |
|
) |
|
expected_output_indices = cut_expected_tensor(next_indices) + offset * self.num_beams |
|
|
|
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist()) |
|
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist()) |
|
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3)) |
|
|
|
|
|
for batch_idx in range(self.batch_size): |
|
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1] |
|
self.parent.assertListEqual( |
|
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist() |
|
) |
|
|
|
def check_constrained_beam_scorer_finalize( |
|
self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab |
|
): |
|
|
|
max_length = self.sequence_length + 1 |
|
|
|
|
|
stacked_token_ids = [] |
|
for constraint in self.constraints: |
|
token_ids = constraint.token_ids |
|
token_ids = token_ids[0] if isinstance(token_ids[0], list) else token_ids |
|
stacked_token_ids = stacked_token_ids + token_ids |
|
|
|
fulfilling_sequence = torch.LongTensor(stacked_token_ids) |
|
|
|
fulfill_len = fulfilling_sequence.size(0) |
|
input_ids[:, :fulfill_len] = fulfilling_sequence |
|
|
|
constrained_beam_scorer = self.prepare_constrained_beam_scorer( |
|
num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False |
|
) |
|
|
|
constraints = constrained_beam_scorer.constraints |
|
|
|
tokens = next_tokens.clone() |
|
|
|
tokens[0, 0] = self.eos_token_id |
|
|
|
next_scores[0, 0] = 0.0 |
|
|
|
beam_outputs = constrained_beam_scorer.process( |
|
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id |
|
) |
|
output_scores = beam_outputs["next_beam_scores"] |
|
output_tokens = beam_outputs["next_beam_tokens"] |
|
output_indices = beam_outputs["next_beam_indices"] |
|
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1) |
|
|
|
|
|
sequence_output = constrained_beam_scorer.finalize( |
|
input_ids, |
|
output_scores, |
|
output_tokens, |
|
output_indices, |
|
pad_token_id=self.pad_token_id, |
|
eos_token_id=self.eos_token_id, |
|
max_length=max_length, |
|
) |
|
|
|
sequences = sequence_output["sequences"] |
|
sequence_scores = sequence_output["sequence_scores"] |
|
|
|
|
|
self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length]) |
|
self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size]) |
|
|
|
|
|
self.parent.assertFalse((sequence_scores > 0).any().item()) |
|
|
|
|
|
self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id) |
|
|
|
|
|
self.parent.assertNotEqual(sequences[1, -1].item(), self.eos_token_id) |
|
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id) |
|
|
|
|
|
for output, constraint in [(s, c) for s in sequences for c in constraints]: |
|
forced_token_ids = constraint.token_ids |
|
if isinstance(forced_token_ids[0], list): |
|
|
|
flag = False |
|
for token_ids in forced_token_ids: |
|
if self._check_sequence_inside_sequence(output, token_ids): |
|
flag = True |
|
break |
|
self.parent.assertEqual(flag, True) |
|
else: |
|
self.parent.assertEqual(self._check_sequence_inside_sequence(output, forced_token_ids), True) |
|
|
|
|
|
|
|
|
|
constrained_beam_scorer = self.prepare_constrained_beam_scorer( |
|
num_beam_hyps_to_keep=self.num_beams, length_penalty=1.0, do_early_stopping=False |
|
) |
|
|
|
sequence_output = constrained_beam_scorer.finalize( |
|
input_ids, |
|
output_scores, |
|
output_tokens, |
|
output_indices, |
|
pad_token_id=self.pad_token_id, |
|
eos_token_id=self.eos_token_id, |
|
max_length=max_length, |
|
) |
|
sequences = sequence_output["sequences"] |
|
sequence_scores = sequence_output["sequence_scores"] |
|
|
|
self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length]) |
|
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size]) |
|
|
|
def _check_sequence_inside_sequence(self, tensor_1, tensor_2): |
|
|
|
|
|
|
|
if not isinstance(tensor_1, list): |
|
tensor_1 = tensor_1.cpu().tolist() |
|
if not isinstance(tensor_2, list): |
|
tensor_2 = tensor_2.cpu().tolist() |
|
|
|
in_order = len(tensor_1) <= len(tensor_2) |
|
longer = tensor_2 if in_order else tensor_1 |
|
shorter = tensor_1 if in_order else tensor_2 |
|
|
|
flag = False |
|
chunk_size = len(shorter) |
|
for chunk_idx in range(len(longer) - chunk_size + 1): |
|
subseq = longer[chunk_idx : chunk_idx + chunk_size] |
|
if subseq == shorter: |
|
flag = True |
|
break |
|
|
|
return flag |
|
|
|
|
|
@require_torch |
|
class BeamSearchTest(unittest.TestCase): |
|
def setUp(self): |
|
self.beam_search_tester = BeamSearchTester(self) |
|
|
|
def test_beam_hypotheses(self): |
|
inputs = self.beam_search_tester.prepare_inputs() |
|
self.beam_search_tester.check_beam_hypotheses(*inputs) |
|
|
|
def test_beam_scorer_update(self): |
|
inputs = self.beam_search_tester.prepare_inputs() |
|
self.beam_search_tester.check_beam_scorer_update(*inputs) |
|
|
|
def test_beam_scorer_finalize(self): |
|
inputs = self.beam_search_tester.prepare_inputs() |
|
self.beam_search_tester.check_beam_scores_finalize(*inputs) |
|
|
|
|
|
@require_torch |
|
class ConstrainedBeamSearchTest(unittest.TestCase): |
|
def setUp(self): |
|
self.constrained_beam_search_tester = ConstrainedBeamSearchTester(self) |
|
|
|
def test_constrained_beam_hypotheses(self): |
|
inputs = self.constrained_beam_search_tester.prepare_inputs() |
|
self.constrained_beam_search_tester.check_beam_hypotheses(*inputs) |
|
|
|
def test_constrained_beam_scorer_update(self): |
|
inputs = self.constrained_beam_search_tester.prepare_inputs() |
|
self.constrained_beam_search_tester.check_constrained_beam_scorer_update(*inputs) |
|
|
|
def test_constrained_beam_scorer_finalize(self): |
|
inputs = self.constrained_beam_search_tester.prepare_inputs() |
|
self.constrained_beam_search_tester.check_constrained_beam_scorer_finalize(*inputs) |
|
|