# coding=utf-8
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device

from .test_modeling_common import floats_tensor, ids_tensor


if is_torch_available():
    import torch

    from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer


class BeamSearchTester:
    """Configuration holder and check routines for the `BeamSearchScorer` tests.

    All assertions are reported through `parent`, the `unittest.TestCase` instance that owns this helper.
    """

    def __init__(
        self,
        parent,
        batch_size=3,
        sequence_length=10,
        vocab_size=99,
        pad_token_id=0,
        max_length=20,
        num_beams=4,
        length_penalty=2.0,
        do_early_stopping=True,
        num_beam_hyps_to_keep=2,
    ):
        """Store the test configuration.

        Args:
            parent: test case whose `assert*` methods are used by the `check_*` routines.
            batch_size: number of batch rows in the generated inputs.
            sequence_length: length of every row of the generated `input_ids`.
            vocab_size: bound passed to `ids_tensor` when generating token ids.
            pad_token_id: padding id forwarded to `BeamSearchScorer.finalize`.
            max_length: stored on the tester; not read by the methods defined in this class.
            num_beams: number of beams per batch row.
            length_penalty: default length penalty handed to the scorer.
            do_early_stopping: default early-stopping flag handed to the scorer.
            num_beam_hyps_to_keep: default number of hypotheses the scorer should return per batch row.
        """
        self.parent = parent
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.max_length = max_length
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep

        # cannot be randomly generated
        # NOTE(review): relies on `ids_tensor` only producing ids below `vocab_size` -- confirm in test_modeling_common
        self.eos_token_id = vocab_size + 1

    def prepare_beam_scorer(self, **kwargs):
        """Build a `BeamSearchScorer` on `torch_device`.

        `batch_size`, `num_beams`, `length_penalty`, `do_early_stopping` and `num_beam_hyps_to_keep` may be
        overridden through `kwargs`; each falls back to the value stored on the tester. Other keys are ignored.
        """
        return BeamSearchScorer(
            batch_size=kwargs.get("batch_size", self.batch_size),
            num_beams=kwargs.get("num_beams", self.num_beams),
            device=torch_device,
            length_penalty=kwargs.get("length_penalty", self.length_penalty),
            do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
            num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep),
        )

    def prepare_inputs(self):
        """Generate random inputs for one beam-search step.

        Returns:
            A tuple `(input_ids, next_tokens, next_indices, next_scores)` where

            - `input_ids` has shape `(batch_size * num_beams, sequence_length)`,
            - `next_tokens`, `next_indices` and `next_scores` have shape `(batch_size, 2 * num_beams)`,
            - `next_indices` is generated with `num_beams` as its bound,
            - `next_scores` is the negated output of `floats_tensor`, sorted in descending order along the
              last dimension.
        """
        input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size)
        next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device)
        next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device)
        next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True)
        return (input_ids, next_tokens, next_indices, next_scores)

    def check_beam_hypotheses(self, input_ids, *args):
        """Check the `BeamHypotheses` containers created by the scorer.

        Only `input_ids` is used; the remaining outputs of `prepare_inputs` are accepted and ignored so that the
        tuple can be splatted into every `check_*` method alike.
        """
        # check that correct number of beam hypotheses is set in beam scorer
        beam_scorer = self.prepare_beam_scorer(do_early_stopping=True)
        beam_hyp = beam_scorer._beam_hyps[0]

        self.parent.assertEqual(len(beam_scorer._beam_hyps), self.batch_size)

        # check correct type
        self.parent.assertIsInstance(beam_hyp, BeamHypotheses)

        # check that num_beams is correctly set
        self.parent.assertEqual(beam_hyp.num_beams, self.num_beams)

        # early stopping is activated here: fill the container with `num_beams` hypotheses
        for beam_idx in range(self.num_beams):
            beam_hyp.add(input_ids[beam_idx], -10.0)

        # if early stopping True -> score does not matter
        self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))

        # re-init, this time with early stopping deactivated
        beam_scorer = self.prepare_beam_scorer(do_early_stopping=False)
        beam_hyp = beam_scorer._beam_hyps[0]

        # add `num_beams + 1` beams to change `worst_score`
        for beam_idx in range(self.num_beams + 1):
            beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))

        # -10.0 is removed => -9.0 is worst score
        self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length ** beam_hyp.length_penalty))

        # -5.0 is better than worst score => should not be finished
        self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))

        # -20.0 is worse than worst score => should be finished
        self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length))

    def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_scores):
        """Check `BeamSearchScorer.process` for three placements of the eos token among the candidates.

        The input tensors are not modified; every scenario works on a clone of `next_tokens`.
        """
        # check too many eos tokens: every candidate of the first batch row is eos => `process` must raise
        beam_scorer = self.prepare_beam_scorer()

        tokens = next_tokens.clone()
        tokens[0, :] = self.eos_token_id

        with self.parent.assertRaises(ValueError):
            beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)

        # check all batches are done: the first `num_beams` candidates of every batch row are eos
        beam_scorer = self.prepare_beam_scorer()

        tokens = next_tokens.clone()
        tokens[:, : self.num_beams] = self.eos_token_id
        beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
        # beam scorer should be done
        self.parent.assertTrue(beam_scorer.is_done)

        # check the regular case: exactly one eos candidate (column 1) per batch row
        beam_scorer = self.prepare_beam_scorer()

        tokens = next_tokens.clone()
        tokens[:, 1] = self.eos_token_id
        beam_outputs = beam_scorer.process(
            input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
        )
        output_scores = beam_outputs["next_beam_scores"]
        output_tokens = beam_outputs["next_beam_tokens"]
        output_indices = beam_outputs["next_beam_indices"]

        def cut_expected_tensor(tensor):
            # drop column 1 (the eos candidate) and keep the next best `num_beams` columns, flattened
            return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten()

        # check all outputs
        # cut out id of eos token and take best `num_beams` outputs
        expected_output_tokens = cut_expected_tensor(tokens)
        expected_output_scores = cut_expected_tensor(next_scores)

        # add num_beams * batch_idx
        expected_output_indices = (
            cut_expected_tensor(next_indices)
            + (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
        )

        self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
        self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
        self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))

        # make sure ids of eos token are correctly saved in beam_hyps of beam scorer
        for batch_idx in range(self.batch_size):
            correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
            self.parent.assertListEqual(
                input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
            )

    def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
        """Check the shapes and eos placement of the sequences returned by `BeamSearchScorer.finalize`.

        The input tensors are not modified: the token and score tensors are cloned before being edited.
        """
        # max_length should be only one more than current input_ids to check that eos is correctly appended
        max_length = self.sequence_length + 1
        beam_scorer = self.prepare_beam_scorer(num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False)

        # update beams and append to input_ids
        tokens = next_tokens.clone()
        # work on a copy so that the caller's score tensor is left untouched
        scores = next_scores.clone()
        # first batch, first output has to finish with eos token id since scores are correctly sorted
        tokens[0, 0] = self.eos_token_id
        # make sure corresponding score is as good as possible to surely be picked first
        scores[0, 0] = 0.0
        beam_outputs = beam_scorer.process(input_ids, scores, tokens, next_indices, eos_token_id=self.eos_token_id)
        output_scores = beam_outputs["next_beam_scores"]
        output_tokens = beam_outputs["next_beam_tokens"]
        output_indices = beam_outputs["next_beam_indices"]

        input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)

        # finalize
        sequence_output = beam_scorer.finalize(
            input_ids,
            output_scores,
            output_tokens,
            output_indices,
            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id,
            max_length=max_length,
        )

        sequences = sequence_output["sequences"]
        sequence_scores = sequence_output["sequence_scores"]

        # since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
        self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length])
        self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size])

        # check sequence_scores
        self.parent.assertFalse((sequence_scores > 0).any().item())

        # first batch has to finish with eos_token
        self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id)

        # other batches cannot finish with eos token (checked for every remaining row, whatever `batch_size` is)
        for batch_idx in range(1, self.batch_size):
            self.parent.assertNotEqual(sequences[batch_idx, -1].item(), self.eos_token_id)

        # now test that if `num_beam_hyps_to_keep` is `num_beams` => all beams are returned
        beam_scorer.num_beam_hyps_to_keep = self.num_beams
        sequence_output = beam_scorer.finalize(
            input_ids,
            output_scores,
            output_tokens,
            output_indices,
            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id,
            max_length=max_length,
        )
        sequences = sequence_output["sequences"]
        sequence_scores = sequence_output["sequence_scores"]

        self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length])
        self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])


@require_torch
class BeamSearchTest(unittest.TestCase):
    """Runs each `BeamSearchTester` check on freshly generated inputs."""

    def setUp(self):
        self.beam_search_tester = BeamSearchTester(self)

    def test_beam_hypotheses(self):
        inputs = self.beam_search_tester.prepare_inputs()
        self.beam_search_tester.check_beam_hypotheses(*inputs)

    def test_beam_scorer_update(self):
        inputs = self.beam_search_tester.prepare_inputs()
        self.beam_search_tester.check_beam_scorer_update(*inputs)

    def test_beam_scorer_finalize(self):
        inputs = self.beam_search_tester.prepare_inputs()
        self.beam_search_tester.check_beam_scores_finalize(*inputs)