|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect |
|
import unittest |
|
|
|
from transformers import is_torch_available |
|
from transformers.testing_utils import require_torch, slow, torch_device |
|
|
|
|
|
if is_torch_available(): |
|
import torch |
|
|
|
from transformers import BartForConditionalGeneration, BartTokenizer, top_k_top_p_filtering |
|
from transformers.generation_beam_search import BeamSearchScorer |
|
from transformers.generation_logits_process import ( |
|
ForcedBOSTokenLogitsProcessor, |
|
ForcedEOSTokenLogitsProcessor, |
|
HammingDiversityLogitsProcessor, |
|
InfNanRemoveLogitsProcessor, |
|
LogitsProcessorList, |
|
MinLengthLogitsProcessor, |
|
NoBadWordsLogitsProcessor, |
|
NoRepeatNGramLogitsProcessor, |
|
RepetitionPenaltyLogitsProcessor, |
|
TemperatureLogitsWarper, |
|
TopKLogitsWarper, |
|
TopPLogitsWarper, |
|
) |
|
from transformers.generation_stopping_criteria import MaxLengthCriteria, StoppingCriteriaList |
|
from transformers.generation_utils import ( |
|
BeamSampleDecoderOnlyOutput, |
|
BeamSampleEncoderDecoderOutput, |
|
BeamSearchDecoderOnlyOutput, |
|
BeamSearchEncoderDecoderOutput, |
|
GreedySearchDecoderOnlyOutput, |
|
GreedySearchEncoderDecoderOutput, |
|
SampleDecoderOnlyOutput, |
|
SampleEncoderDecoderOutput, |
|
) |
|
|
|
|
|
class GenerationTesterMixin: |
|
model_tester = None |
|
all_generative_model_classes = () |
|
input_name = "input_ids" |
|
|
|
def _get_input_ids_and_config(self): |
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() |
|
|
|
input_ids = inputs_dict[self.input_name] |
|
attention_mask = torch.ones_like(input_ids, dtype=torch.long) |
|
|
|
|
|
max_batch_size = 2 |
|
sequence_length = input_ids.shape[-1] // 2 |
|
input_ids = input_ids[:max_batch_size, :sequence_length] |
|
attention_mask = attention_mask[:max_batch_size, :sequence_length] |
|
|
|
|
|
max_length = input_ids.shape[-1] + 3 |
|
if config.eos_token_id is not None and config.pad_token_id is None: |
|
|
|
config.pad_token_id = config.eos_token_id |
|
return config, input_ids, attention_mask, max_length |
|
|
|
@staticmethod |
|
def _get_logits_processor_and_kwargs( |
|
input_length, |
|
eos_token_id, |
|
forced_bos_token_id=None, |
|
forced_eos_token_id=None, |
|
max_length=None, |
|
diversity_penalty=None, |
|
): |
|
process_kwargs = { |
|
"min_length": input_length + 1, |
|
"bad_words_ids": [[1, 0]], |
|
"no_repeat_ngram_size": 2, |
|
"repetition_penalty": 1.2, |
|
} |
|
logits_processor = LogitsProcessorList( |
|
( |
|
[ |
|
HammingDiversityLogitsProcessor(diversity_penalty, num_beams=2, num_beam_groups=2), |
|
] |
|
if diversity_penalty is not None |
|
else [] |
|
) |
|
+ ( |
|
[ |
|
MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id), |
|
] |
|
if eos_token_id is not None |
|
else [] |
|
) |
|
+ ( |
|
[ |
|
ForcedBOSTokenLogitsProcessor(forced_bos_token_id), |
|
] |
|
if forced_bos_token_id is not None |
|
else [] |
|
) |
|
+ ( |
|
[ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)] |
|
if forced_eos_token_id is not None |
|
else [] |
|
) |
|
+ [ |
|
NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id), |
|
NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]), |
|
RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]), |
|
] |
|
) |
|
return process_kwargs, logits_processor |
|
|
|
@staticmethod |
|
def _get_warper_and_kwargs(num_beams): |
|
warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7} |
|
logits_warper = LogitsProcessorList( |
|
[ |
|
TemperatureLogitsWarper(warp_kwargs["temperature"]), |
|
TopKLogitsWarper(top_k=warp_kwargs["top_k"], min_tokens_to_keep=(2 if num_beams > 1 else 1)), |
|
TopPLogitsWarper(top_p=warp_kwargs["top_p"], min_tokens_to_keep=(2 if num_beams > 1 else 1)), |
|
] |
|
) |
|
return warp_kwargs, logits_warper |
|
|
|
@staticmethod |
|
def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1): |
|
beam_kwargs = { |
|
"early_stopping": False, |
|
"length_penalty": 2.0, |
|
"num_beams": 2, |
|
"num_return_sequences": num_return_sequences, |
|
} |
|
beam_scorer = BeamSearchScorer( |
|
batch_size=batch_size, |
|
num_beams=beam_kwargs["num_beams"], |
|
device=torch_device, |
|
length_penalty=beam_kwargs["length_penalty"], |
|
do_early_stopping=beam_kwargs["early_stopping"], |
|
num_beam_hyps_to_keep=num_return_sequences, |
|
) |
|
return beam_kwargs, beam_scorer |
|
|
|
@staticmethod |
|
def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1): |
|
beam_kwargs = { |
|
"early_stopping": False, |
|
"length_penalty": 2.0, |
|
"num_beams": 2, |
|
"num_return_sequences": num_return_sequences, |
|
"num_beam_groups": 2, |
|
"diversity_penalty": 2.0, |
|
} |
|
beam_scorer = BeamSearchScorer( |
|
batch_size=batch_size, |
|
num_beams=beam_kwargs["num_beams"], |
|
device=torch_device, |
|
length_penalty=beam_kwargs["length_penalty"], |
|
do_early_stopping=beam_kwargs["early_stopping"], |
|
num_beam_hyps_to_keep=num_return_sequences, |
|
num_beam_groups=beam_kwargs["num_beam_groups"], |
|
) |
|
return beam_kwargs, beam_scorer |
|
|
|
@staticmethod |
|
def _get_encoder_outputs( |
|
model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1 |
|
): |
|
encoder = model.get_encoder() |
|
encoder_outputs = encoder( |
|
input_ids, |
|
attention_mask=attention_mask, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave( |
|
num_interleave, dim=0 |
|
) |
|
input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id() |
|
attention_mask = None |
|
return encoder_outputs, input_ids, attention_mask |
|
|
|
    def _greedy_generate(
        self,
        model,
        input_ids,
        attention_mask,
        max_length,
        output_scores=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict_in_generate=False,
    ):
        """Run greedy decoding twice — via `generate()` and via `greedy_search()` directly —
        and return `(output_greedy, output_generate)` for comparison."""
        if model.config.is_encoder_decoder:
            # Keep encoder-decoder runs short; deliberately overrides the caller's max_length.
            max_length = 4
        logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
            input_ids.shape[-1],
            eos_token_id=model.config.eos_token_id,
            forced_bos_token_id=model.config.forced_bos_token_id,
            forced_eos_token_id=model.config.forced_eos_token_id,
            max_length=max_length,
        )

        kwargs = {}

        # High-level path: `generate()` with do_sample=False / num_beams=1 is greedy search.
        output_generate = model.generate(
            input_ids,
            attention_mask=attention_mask,
            do_sample=False,
            num_beams=1,
            max_length=max_length,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            remove_invalid_values=True,
            **logits_process_kwargs,
        )

        if model.config.is_encoder_decoder:
            # Low-level path needs precomputed encoder outputs and decoder-start input ids;
            # note this rebinds `input_ids`/`attention_mask` for the call below.
            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                model,
                input_ids,
                attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            kwargs["encoder_outputs"] = encoder_outputs

        with torch.no_grad():
            output_greedy = model.greedy_search(
                input_ids,
                max_length=max_length,
                attention_mask=attention_mask,
                logits_processor=logits_processor,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                **kwargs,
            )
        return output_greedy, output_generate
|
|
|
    def _sample_generate(
        self,
        model,
        input_ids,
        attention_mask,
        max_length,
        num_return_sequences,
        logits_processor,
        logits_warper,
        logits_warper_kwargs,
        process_kwargs,
        output_scores=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict_in_generate=False,
    ):
        """Run multinomial sampling twice — via `generate()` and via `sample()` directly —
        and return `(output_sample, output_generate)` for comparison.

        Side effect: appends an `InfNanRemoveLogitsProcessor` to the caller's
        `logits_processor` in place (mirrors `remove_invalid_values=True` in `generate()`).
        """
        # Seed so both paths draw the same random token sequence.
        torch.manual_seed(0)
        output_generate = model.generate(
            input_ids,
            do_sample=True,
            num_beams=1,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            attention_mask=attention_mask,
            output_scores=output_scores,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict_in_generate=return_dict_in_generate,
            remove_invalid_values=True,
            **logits_warper_kwargs,
            **process_kwargs,
        )

        # Re-seed so the direct `sample()` call sees the identical random stream.
        torch.manual_seed(0)
        kwargs = {}
        if model.config.is_encoder_decoder:
            # Precompute encoder outputs, interleaved once per returned sequence.
            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
                model,
                input_ids,
                attention_mask,
                num_interleave=num_return_sequences,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            kwargs["encoder_outputs"] = encoder_outputs
            input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, dim=0)
        else:
            # Decoder-only: expand the inputs directly, one copy per returned sequence.
            attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
            input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)

        # `generate()` ran with remove_invalid_values=True; add the equivalent processor here.
        logits_processor.append(InfNanRemoveLogitsProcessor())

        with torch.no_grad():
            output_sample = model.sample(
                input_ids_clone,
                attention_mask=attention_mask_clone,
                max_length=max_length,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                output_scores=output_scores,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict_in_generate=return_dict_in_generate,
                **kwargs,
            )
        return output_sample, output_generate
|
|
|
    def _beam_search_generate(
        self,
        model,
        input_ids,
        attention_mask,
        max_length,
        beam_scorer,
        beam_kwargs,
        logits_processor,
        logits_process_kwargs,
        output_scores=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict_in_generate=False,
    ):
        """Run beam search twice — via `generate()` and via `beam_search()` directly —
        and return `(output_generate, output_beam_search)` for comparison."""
        output_generate = model.generate(
            input_ids,
            attention_mask=attention_mask,
            do_sample=False,
            max_length=max_length,
            output_scores=output_scores,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict_in_generate=return_dict_in_generate,
            remove_invalid_values=True,
            **beam_kwargs,
            **logits_process_kwargs,
        )

        # Low-level path: inputs must be expanded to one row per beam before `beam_search()`.
        kwargs = {}
        if model.config.is_encoder_decoder:
            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
                model,
                input_ids,
                attention_mask,
                num_interleave=beam_scorer.num_beams,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            kwargs["encoder_outputs"] = encoder_outputs
            input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
        else:
            attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
            input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)

        with torch.no_grad():
            output_beam_search = model.beam_search(
                input_ids_clone,
                beam_scorer,
                max_length=max_length,
                attention_mask=attention_mask_clone,
                logits_processor=logits_processor,
                output_scores=output_scores,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict_in_generate=return_dict_in_generate,
                **kwargs,
            )
        return output_generate, output_beam_search
|
|
|
    def _beam_sample_generate(
        self,
        model,
        input_ids,
        attention_mask,
        max_length,
        num_return_sequences,
        beam_scorer,
        beam_kwargs,
        logits_warper,
        logits_warper_kwargs,
        output_scores=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict_in_generate=False,
    ):
        """Run beam sampling twice — via `generate()` and via `beam_sample()` directly —
        and return `(output_generate, output_beam_sample)` for comparison."""
        # Seed so both paths draw the same random token sequence.
        torch.manual_seed(0)
        output_generate = model.generate(
            input_ids,
            attention_mask=attention_mask,
            do_sample=True,
            max_length=max_length,
            output_scores=output_scores,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict_in_generate=return_dict_in_generate,
            remove_invalid_values=True,
            **beam_kwargs,
            **logits_warper_kwargs,
        )

        # Low-level path: inputs are expanded to beams * return-sequences rows.
        # Note this rebinds `input_ids`/`attention_mask` for the call below.
        kwargs = {}
        if model.config.is_encoder_decoder:
            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                model,
                input_ids,
                attention_mask,
                num_interleave=beam_scorer.num_beams * num_return_sequences,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            kwargs["encoder_outputs"] = encoder_outputs
        else:
            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)

        # `generate()` ran with remove_invalid_values=True; mirror it with an explicit processor.
        logits_processor = LogitsProcessorList()
        logits_processor.append(InfNanRemoveLogitsProcessor())

        # Re-seed so the direct `beam_sample()` call sees the identical random stream.
        torch.manual_seed(0)
        with torch.no_grad():
            output_beam_sample = model.beam_sample(
                input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0),
                beam_scorer,
                max_length=max_length,
                attention_mask=attention_mask,
                logits_warper=logits_warper,
                logits_processor=logits_processor,
                output_scores=output_scores,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict_in_generate=return_dict_in_generate,
                **kwargs,
            )

        return output_generate, output_beam_sample
|
|
|
    def _group_beam_search_generate(
        self,
        model,
        input_ids,
        attention_mask,
        max_length,
        beam_scorer,
        beam_kwargs,
        logits_processor,
        logits_process_kwargs,
        output_scores=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict_in_generate=False,
    ):
        """Run diverse beam search twice — via `generate()` and via `group_beam_search()`
        directly — and return `(output_generate, output_group_beam_search)` for comparison."""
        output_generate = model.generate(
            input_ids,
            attention_mask=attention_mask,
            do_sample=False,
            max_length=max_length,
            output_scores=output_scores,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict_in_generate=return_dict_in_generate,
            remove_invalid_values=True,
            **beam_kwargs,
            **logits_process_kwargs,
        )

        # Low-level path: inputs must be expanded to one row per beam first.
        kwargs = {}
        if model.config.is_encoder_decoder:
            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
                model,
                input_ids,
                attention_mask,
                num_interleave=beam_scorer.num_beams,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            kwargs["encoder_outputs"] = encoder_outputs
            input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
        else:
            attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
            input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)

        with torch.no_grad():
            output_group_beam_search = model.group_beam_search(
                input_ids_clone,
                beam_scorer,
                max_length=max_length,
                attention_mask=attention_mask_clone,
                logits_processor=logits_processor,
                output_scores=output_scores,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict_in_generate=return_dict_in_generate,
                **kwargs,
            )
        return output_generate, output_group_beam_search
|
|
|
def test_greedy_generate(self): |
|
|
|
for model_class in self.all_generative_model_classes: |
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() |
|
|
|
model = model_class(config).to(torch_device).eval() |
|
output_greedy, output_generate = self._greedy_generate( |
|
model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length |
|
) |
|
self.assertListEqual(output_greedy.tolist(), output_generate.tolist()) |
|
|
|
def test_greedy_generate_dict_outputs(self): |
|
for model_class in self.all_generative_model_classes: |
|
|
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() |
|
config.use_cache = False |
|
model = model_class(config).to(torch_device).eval() |
|
output_greedy, output_generate = self._greedy_generate( |
|
model=model, |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
max_length=max_length, |
|
output_scores=True, |
|
output_hidden_states=True, |
|
output_attentions=True, |
|
return_dict_in_generate=True, |
|
) |
|
|
|
if model.config.is_encoder_decoder: |
|
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput) |
|
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput) |
|
else: |
|
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput) |
|
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput) |
|
|
|
self.assertListEqual(output_generate.sequences.tolist(), output_greedy.sequences.tolist()) |
|
|
|
for output in (output_greedy, output_generate): |
|
self._check_outputs(output, input_ids, model.config) |
|
|
|
def test_greedy_generate_dict_outputs_use_cache(self): |
|
for model_class in self.all_generative_model_classes: |
|
|
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() |
|
|
|
if not hasattr(config, "use_cache"): |
|
|
|
return |
|
|
|
config.use_cache = True |
|
config.is_decoder = True |
|
model = model_class(config).to(torch_device).eval() |
|
output_greedy, output_generate = self._greedy_generate( |
|
model=model, |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
max_length=max_length, |
|
output_scores=True, |
|
output_hidden_states=True, |
|
output_attentions=True, |
|
return_dict_in_generate=True, |
|
) |
|
|
|
self.assertListEqual(output_generate.sequences.tolist(), output_greedy.sequences.tolist()) |
|
|
|
for output in (output_greedy, output_generate): |
|
self._check_outputs(output, input_ids, model.config, use_cache=True) |
|
|
|
    def test_sample_generate(self):
        """`generate(do_sample=True)` must match a direct, identically-seeded `sample()` call."""
        for model_class in self.all_generative_model_classes:
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
            model = model_class(config).to(torch_device).eval()

            if model.config.is_encoder_decoder:
                # Keep encoder-decoder runs short.
                max_length = 4

            process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
                input_ids.shape[-1],
                model.config.eos_token_id,
                forced_bos_token_id=model.config.forced_bos_token_id,
                forced_eos_token_id=model.config.forced_eos_token_id,
                max_length=max_length,
            )
            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)

            # Check `generate()` and `sample()` agree for a single returned sequence.
            output_sample, output_generate = self._sample_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                num_return_sequences=1,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                logits_warper_kwargs=logits_warper_kwargs,
                process_kwargs=process_kwargs,
            )
            self.assertListEqual(output_sample.tolist(), output_generate.tolist())

            # ...and again with multiple returned sequences. NOTE: `_sample_generate` appends
            # an InfNanRemoveLogitsProcessor to the shared `logits_processor` on each call, so
            # this second call runs with one extra (idempotent) processor.
            output_sample, output_generate = self._sample_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                num_return_sequences=3,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                logits_warper_kwargs=logits_warper_kwargs,
                process_kwargs=process_kwargs,
            )
            self.assertListEqual(output_sample.tolist(), output_generate.tolist())
|
|
|
    def test_sample_generate_dict_output(self):
        """Sampling with dict outputs: both paths return the right type and equal sequences."""
        for model_class in self.all_generative_model_classes:
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
            # Disable the cache so attentions / hidden states cover the whole sequence.
            config.use_cache = False
            model = model_class(config).to(torch_device).eval()
            if model.config.is_encoder_decoder:
                # Keep encoder-decoder runs short.
                max_length = 4

            process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
                input_ids.shape[-1],
                model.config.eos_token_id,
                forced_bos_token_id=model.config.forced_bos_token_id,
                forced_eos_token_id=model.config.forced_eos_token_id,
                max_length=max_length,
            )
            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)

            output_sample, output_generate = self._sample_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                num_return_sequences=2,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                logits_warper_kwargs=logits_warper_kwargs,
                process_kwargs=process_kwargs,
                output_scores=True,
                output_hidden_states=True,
                output_attentions=True,
                return_dict_in_generate=True,
            )

            if model.config.is_encoder_decoder:
                self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
                self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
            else:
                self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
                self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)

            self.assertListEqual(output_generate.sequences.tolist(), output_sample.sequences.tolist())

            for output in (output_sample, output_generate):
                self._check_outputs(output, input_ids, model.config, num_return_sequences=2)
|
|
|
    def test_beam_search_generate(self):
        """`generate()` with beam kwargs must match a direct `beam_search()` call."""
        for model_class in self.all_generative_model_classes:
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

            # Disable EOS so every hypothesis runs to max_length and outputs stay comparable.
            config.eos_token_id = None
            config.forced_eos_token_id = None

            model = model_class(config).to(torch_device).eval()
            if model.config.is_encoder_decoder:
                # Keep encoder-decoder runs short.
                max_length = 4

            logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
                input_ids.shape[-1],
                config.eos_token_id,
                config.forced_bos_token_id,
                config.forced_eos_token_id,
                max_length,
            )
            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)

            # Check `generate()` and `beam_search()` agree with a single returned hypothesis.
            output_generate, output_beam_search = self._beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_process_kwargs=logits_process_kwargs,
                logits_processor=logits_processor,
            )
            self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())

            # ...and again with multiple returned sequences (fresh scorer required, since a
            # BeamSearchScorer is stateful and cannot be reused).
            num_return_sequences = 2
            if model.config.is_encoder_decoder:
                max_length = 4
            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
                input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
            )

            output_generate, output_beam_search = self._beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_process_kwargs=logits_process_kwargs,
                logits_processor=logits_processor,
            )
            self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
|
|
|
    def test_beam_search_generate_dict_output(self):
        """Beam search with dict outputs: right types, equal sequences, and valid scores."""
        for model_class in self.all_generative_model_classes:
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

            # Disable the cache so attentions / hidden states cover the whole sequence.
            config.use_cache = False

            # Disable EOS so every hypothesis runs to max_length and outputs stay comparable.
            config.eos_token_id = None
            config.forced_eos_token_id = None

            model = model_class(config).to(torch_device).eval()
            if model.config.is_encoder_decoder:
                # Keep encoder-decoder runs short.
                max_length = 4

            logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
                input_ids.shape[-1],
                config.eos_token_id,
                config.forced_bos_token_id,
                config.forced_eos_token_id,
                max_length,
            )
            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
            output_generate, output_beam_search = self._beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_process_kwargs=logits_process_kwargs,
                logits_processor=logits_processor,
                output_scores=True,
                output_hidden_states=True,
                output_attentions=True,
                return_dict_in_generate=True,
            )
            if model.config.is_encoder_decoder:
                self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
                self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
            else:
                self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
                self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)

            self.assertListEqual(output_generate.sequences.tolist(), output_beam_search.sequences.tolist())
            # Scores may differ by tiny float error between the two paths.
            self.assertTrue(
                torch.allclose(output_generate["sequences_scores"], output_beam_search["sequences_scores"], atol=1e-3)
            )
            self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
            # Sequence scores are log-probabilities, hence strictly negative.
            self.assertTrue((output_generate["sequences_scores"] < 0).all().item())

            for output in (output_beam_search, output_generate):
                self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)
|
|
|
def test_beam_search_generate_dict_outputs_use_cache(self): |
|
for model_class in self.all_generative_model_classes: |
|
|
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() |
|
|
|
|
|
|
|
|
|
config.eos_token_id = None |
|
config.forced_eos_token_id = None |
|
|
|
if not hasattr(config, "use_cache"): |
|
|
|
return |
|
|
|
model = model_class(config).to(torch_device).eval() |
|
if model.config.is_encoder_decoder: |
|
max_length = 4 |
|
|
|
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( |
|
input_ids.shape[-1], |
|
config.eos_token_id, |
|
config.forced_bos_token_id, |
|
config.forced_eos_token_id, |
|
max_length, |
|
) |
|
|
|
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) |
|
|
|
config.use_cache = True |
|
config.is_decoder = True |
|
model = model_class(config).to(torch_device).eval() |
|
output_beam, output_generate = self._beam_search_generate( |
|
model=model, |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
max_length=max_length, |
|
beam_scorer=beam_scorer, |
|
beam_kwargs=beam_kwargs, |
|
logits_process_kwargs=logits_process_kwargs, |
|
logits_processor=logits_processor, |
|
output_scores=True, |
|
output_hidden_states=True, |
|
output_attentions=True, |
|
return_dict_in_generate=True, |
|
) |
|
|
|
self.assertListEqual(output_generate.sequences.tolist(), output_beam.sequences.tolist()) |
|
|
|
for output in (output_beam, output_generate): |
|
self._check_outputs( |
|
output, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams |
|
) |
|
|
|
    def test_beam_sample_generate(self):
        """`generate(do_sample=True)` with beam kwargs must match a direct `beam_sample()` call."""
        for model_class in self.all_generative_model_classes:
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

            # Disable EOS so every hypothesis runs to max_length and outputs stay comparable.
            config.eos_token_id = None
            config.forced_eos_token_id = None

            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)

            model = model_class(config).to(torch_device).eval()

            # Check `generate()` and `beam_sample()` agree for multiple returned sequences.
            num_return_sequences = 2
            if model.config.is_encoder_decoder:
                # Keep encoder-decoder runs short.
                max_length = 4
            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
                input_ids.shape[0] * num_return_sequences, max_length
            )
            beam_kwargs["num_return_sequences"] = num_return_sequences

            output_generate, output_beam_sample = self._beam_sample_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                num_return_sequences=num_return_sequences,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_warper=logits_warper,
                logits_warper_kwargs=logits_warper_kwargs,
            )
            self.assertListEqual(output_generate.tolist(), output_beam_sample.tolist())
|
|
|
def test_beam_sample_generate_dict_output(self): |
|
for model_class in self.all_generative_model_classes: |
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() |
|
|
|
|
|
config.use_cache = False |
|
|
|
|
|
|
|
|
|
config.eos_token_id = None |
|
config.forced_eos_token_id = None |
|
|
|
model = model_class(config).to(torch_device).eval() |
|
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) |
|
|
|
num_return_sequences = 2 |
|
if model.config.is_encoder_decoder: |
|
max_length = 4 |
|
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs( |
|
input_ids.shape[0] * num_return_sequences, max_length |
|
) |
|
beam_kwargs["num_return_sequences"] = num_return_sequences |
|
|
|
output_beam_sample, output_generate = self._beam_sample_generate( |
|
model=model, |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
max_length=max_length, |
|
num_return_sequences=num_return_sequences, |
|
beam_scorer=beam_scorer, |
|
beam_kwargs=beam_kwargs, |
|
logits_warper=logits_warper, |
|
logits_warper_kwargs=logits_warper_kwargs, |
|
output_scores=True, |
|
output_hidden_states=True, |
|
output_attentions=True, |
|
return_dict_in_generate=True, |
|
) |
|
|
|
if model.config.is_encoder_decoder: |
|
self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput) |
|
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput) |
|
else: |
|
self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput) |
|
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput) |
|
|
|
self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist()) |
|
self.assertTrue( |
|
torch.allclose(output_generate["sequences_scores"], output_beam_sample["sequences_scores"], atol=1e-3) |
|
) |
|
self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) |
|
self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) |
|
|
|
for output in (output_beam_sample, output_generate): |
|
self._check_outputs( |
|
output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams |
|
) |
|
|
|
def test_generate_without_input_ids(self): |
|
config, _, _, max_length = self._get_input_ids_and_config() |
|
|
|
|
|
if config.bos_token_id is None: |
|
return |
|
|
|
for model_class in self.all_generative_model_classes: |
|
model = model_class(config).to(torch_device) |
|
model.eval() |
|
|
|
output_ids_generate = model.generate( |
|
do_sample=False, |
|
max_length=max_length, |
|
remove_invalid_values=True, |
|
) |
|
|
|
self.assertIsNotNone(output_ids_generate) |
|
|
|
    def test_group_beam_search_generate(self):
        """`generate()` with diverse-beam kwargs must match a direct `group_beam_search()` call."""
        for model_class in self.all_generative_model_classes:
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

            # Disable EOS so every hypothesis runs to max_length and outputs stay comparable.
            config.eos_token_id = None
            config.forced_eos_token_id = None

            model = model_class(config).to(torch_device).eval()
            if model.config.is_encoder_decoder:
                # Keep encoder-decoder runs short.
                max_length = 4

            logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
                input_ids.shape[-1],
                config.eos_token_id,
                config.forced_bos_token_id,
                config.forced_eos_token_id,
                max_length,
                diversity_penalty=2.0,  # enables the HammingDiversityLogitsProcessor
            )

            # Check `generate()` and `group_beam_search()` agree with a single returned hypothesis.
            beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
            output_generate, output_group_beam_search = self._group_beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_processor=logits_processor,
                logits_process_kwargs=logits_process_kwargs,
            )
            self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist())

            # ...and again with multiple returned sequences (fresh scorer required, since a
            # BeamSearchScorer is stateful and cannot be reused).
            num_return_sequences = 2
            if model.config.is_encoder_decoder:
                max_length = 4
            beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
                input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
            )
            output_generate, output_group_beam_search = self._group_beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_processor=logits_processor,
                logits_process_kwargs=logits_process_kwargs,
            )
            self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist())
|
|
|
def test_group_beam_search_generate_dict_output(self):
    """Check that `generate()` and `group_beam_search()` produce identical dict-style
    outputs (sequences, scores, attentions, hidden states) for diverse beam search."""
    for model_class in self.all_generative_model_classes:
        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
        config.use_cache = False

        # Disable every EOS mechanism so generation always runs to `max_length`.
        config.eos_token_id = None
        config.forced_eos_token_id = None

        model = model_class(config).to(torch_device).eval()
        if model.config.is_encoder_decoder:
            max_length = 4

        logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
            input_ids.shape[-1],
            config.eos_token_id,
            config.forced_bos_token_id,
            config.forced_eos_token_id,
            max_length,
            diversity_penalty=2.0,
        )

        num_return_sequences = 1
        beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
            input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
        )
        # Request every optional output so the full dict structure gets exercised.
        output_generate, output_group_beam_search = self._group_beam_search_generate(
            model=model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            beam_scorer=beam_scorer,
            beam_kwargs=beam_kwargs,
            logits_processor=logits_processor,
            logits_process_kwargs=logits_process_kwargs,
            output_scores=True,
            output_hidden_states=True,
            output_attentions=True,
            return_dict_in_generate=True,
        )
        # The returned dataclass type depends on the model architecture.
        if model.config.is_encoder_decoder:
            self.assertIsInstance(output_group_beam_search, BeamSearchEncoderDecoderOutput)
            self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
        else:
            self.assertIsInstance(output_group_beam_search, BeamSearchDecoderOnlyOutput)
            self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)

        self.assertListEqual(output_generate.sequences.tolist(), output_group_beam_search.sequences.tolist())
        self.assertTrue(
            torch.allclose(
                output_generate["sequences_scores"], output_group_beam_search["sequences_scores"], atol=1e-3
            )
        )
        # One score per returned sequence; scores are log-probabilities, hence negative.
        self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
        self.assertTrue((output_generate["sequences_scores"] < 0).all().item())

        for output in (output_group_beam_search, output_generate):
            self._check_outputs(
                output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
            )
|
|
|
def test_generate_with_head_masking(self):
    """Test designed for encoder-decoder models to ensure the attention head masking is used.

    For each head-mask argument, generation is run with an all-zero mask and the
    corresponding attention weights are asserted to sum to exactly zero.
    """
    attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
    for model_class in self.all_generative_model_classes:
        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
        # `.eval()` matches the other generation tests in this file and disables
        # dropout, keeping the attention sums deterministic.
        model = model_class(config).to(torch_device).eval()

        # Head masking during generate() is only exercised for encoder-decoder models here.
        if not config.is_encoder_decoder:
            continue

        # One all-zero mask per attention kind (encoder self-, decoder self-, cross-attention).
        head_masking = {
            "head_mask": torch.zeros(config.encoder_layers, config.encoder_attention_heads, device=torch_device),
            "decoder_head_mask": torch.zeros(
                config.decoder_layers, config.decoder_attention_heads, device=torch_device
            ),
            "cross_attn_head_mask": torch.zeros(
                config.decoder_layers, config.decoder_attention_heads, device=torch_device
            ),
        }

        signature = inspect.signature(model.forward)
        # Skip models whose forward() does not accept all three head-mask arguments.
        if not set(head_masking.keys()) < set(signature.parameters.keys()):
            continue

        for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
            out = model.generate(
                input_ids,
                attention_mask=attention_mask,
                num_beams=1,
                output_attentions=True,
                return_dict_in_generate=True,
                remove_invalid_values=True,
                **{name: mask},
            )
            # Encoder attentions are a flat per-layer tuple; decoder/cross attentions
            # are additionally indexed by generation step, so check the last step only.
            attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
            # With every head masked out, the attention weights must all be zero.
            self.assertEqual(sum(w.sum().item() for w in attn_weights), 0.0)
|
|
|
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
    """Validate the structure of a dict-style generation `output`.

    Dispatches to the `_check_*` helpers for scores, attentions, and hidden
    states, handling the encoder-decoder vs. decoder-only layouts and the
    shape differences introduced by key/value caching.
    """
    batch_size, seq_length = input_ids.shape
    num_sequences_in_output = batch_size * num_return_sequences
    # Encoder-decoder outputs start from a single decoder start token; decoder-only
    # outputs contain the prompt, so subtract its length.
    gen_len = (
        output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
    )

    # scores
    self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)

    # attentions
    if config.is_encoder_decoder:
        # Encoder self-attentions are computed once over the full input.
        self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
        # Decoder attentions grow from a single start token.
        self._check_attentions_for_generate(
            num_sequences_in_output,
            output.decoder_attentions,
            min_length=1,
            max_length=output.sequences.shape[-1],
            config=config,
            use_cache=use_cache,
        )
    else:
        # With caching, the first step processes the whole prompt and later steps
        # only the newest token, so the first entry is skipped.
        attentions = output.attentions if not use_cache else output.attentions[1:]
        min_length = seq_length if not use_cache else seq_length + 1
        self._check_attentions_for_generate(
            num_sequences_in_output,
            attentions=attentions,
            min_length=min_length,
            max_length=output.sequences.shape[-1],
            config=config,
            use_cache=use_cache,
        )

    # hidden states
    if config.is_encoder_decoder:
        # Encoder hidden states are computed once over the full input.
        self._check_encoder_hidden_states_for_generate(
            output.encoder_hidden_states, batch_size, config, seq_length
        )

        # Decoder hidden states grow from a single start token.
        self._check_hidden_states_for_generate(
            num_sequences_in_output,
            output.decoder_hidden_states,
            min_length=1,
            max_length=output.sequences.shape[-1],
            config=config,
            use_cache=use_cache,
        )
    else:
        # Same cache handling as for the attentions above.
        hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
        min_length = seq_length if not use_cache else seq_length + 1
        self._check_hidden_states_for_generate(
            num_sequences_in_output,
            hidden_states,
            min_length=min_length,
            max_length=output.sequences.shape[-1],
            config=config,
            use_cache=use_cache,
        )
|
|
|
def _check_scores(self, batch_size, scores, length, config): |
|
expected_shape = (batch_size, config.vocab_size) |
|
self.assertIsInstance(scores, tuple) |
|
self.assertEqual(len(scores), length) |
|
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores)) |
|
|
|
def _check_attentions_for_generate( |
|
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 |
|
): |
|
self.assertIsInstance(attentions, tuple) |
|
self.assertListEqual( |
|
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) |
|
) |
|
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) |
|
|
|
for idx, iter_attentions in enumerate(attentions): |
|
tgt_len = min_length + idx if not use_cache else 1 |
|
src_len = min_length + idx |
|
|
|
expected_shape = ( |
|
batch_size * num_beam_groups, |
|
config.num_attention_heads, |
|
tgt_len, |
|
src_len, |
|
) |
|
|
|
self.assertListEqual( |
|
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) |
|
) |
|
|
|
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): |
|
encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length) |
|
self.assertIsInstance(attentions, tuple) |
|
self.assertListEqual( |
|
[layer_attentions.shape for layer_attentions in attentions], |
|
[encoder_expected_shape] * len(attentions), |
|
) |
|
|
|
def _check_hidden_states_for_generate( |
|
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 |
|
): |
|
self.assertIsInstance(hidden_states, tuple) |
|
self.assertListEqual( |
|
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], |
|
[True] * len(hidden_states), |
|
) |
|
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) |
|
|
|
for idx, iter_hidden_states in enumerate(hidden_states): |
|
seq_len = min_length + idx if not use_cache else 1 |
|
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) |
|
|
|
self.assertListEqual( |
|
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], |
|
[expected_shape] * len(iter_hidden_states), |
|
) |
|
|
|
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length): |
|
encoder_expected_shape = (batch_size, seq_length, config.hidden_size) |
|
self.assertIsInstance(hidden_states, tuple) |
|
self.assertListEqual( |
|
[layer_hidden_states.shape for layer_hidden_states in hidden_states], |
|
[encoder_expected_shape] * len(hidden_states), |
|
) |
|
|
|
|
|
@require_torch
class UtilsFunctionsTest(unittest.TestCase):
    """Unit tests for free-standing generation helper functions."""

    def test_top_k_top_p_filtering(self):
        """`top_k_top_p_filtering` must keep only the logits that survive both the
        top-k and top-p (nucleus) filters — subject to `min_tokens_to_keep` — and
        set every other position to -inf."""
        # Two rows of 30 raw logits each, i.e. a batch of 2 over a vocab of size 30.
        logits = torch.tensor(
            [
                [
                    8.2220991,
                    -0.5620044,
                    5.23229752,
                    4.0386393,
                    -6.8798378,
                    -0.54785802,
                    -3.2012153,
                    2.92777176,
                    1.88171953,
                    7.35341276,
                    8.43207833,
                    -9.85711836,
                    -5.96209236,
                    -1.13039161,
                    -7.1115294,
                    -0.8369633,
                    -5.3186408,
                    7.06427407,
                    0.81369344,
                    -0.82023817,
                    -5.9179796,
                    0.58813443,
                    -6.99778438,
                    4.71551189,
                    -0.18771637,
                    7.44020759,
                    9.38450987,
                    2.12662941,
                    -9.32562038,
                    2.35652522,
                ],
                [
                    0.58425518,
                    4.53139238,
                    -5.57510464,
                    -6.28030699,
                    -7.19529503,
                    -4.02122551,
                    1.39337037,
                    -6.06707057,
                    1.59480517,
                    -9.643119,
                    0.03907799,
                    0.67231762,
                    -8.88206726,
                    6.27115922,
                    2.28520723,
                    4.82767506,
                    4.30421368,
                    8.8275313,
                    5.44029958,
                    -4.4735794,
                    7.38579536,
                    -2.91051663,
                    2.61946077,
                    -2.5674762,
                    -9.48959302,
                    -4.02922645,
                    -1.35416918,
                    9.67702323,
                    -5.89478553,
                    1.85370467,
                ],
            ],
            dtype=torch.float,
            device=torch_device,
        )

        # (batch, vocab) index pairs expected to survive the filtering — four per row,
        # matching `min_tokens_to_keep=4`.
        non_inf_expected_idx = torch.tensor(
            [[0, 0], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 20], [1, 27]],
            dtype=torch.long,
            device=torch_device,
        )

        # Expected values at the surviving indices. NOTE(review): these 4-decimal
        # literals presumably coincide bit-for-bit with the inputs at float32
        # precision, which is why the very tight `atol=1e-12` below can pass.
        non_inf_expected_output = torch.tensor(
            [
                8.2221,
                8.4321,
                7.4402,
                9.3845,
                6.2712,
                8.8275,
                7.3858,
                9.6770,
            ],
            dtype=torch.float,
            device=torch_device,
        )

        output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
        # Collect the values and positions that were NOT filtered to -inf.
        non_inf_output = output[output != -float("inf")].to(device=torch_device)
        non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)

        self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
        self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
|
|
|
|
|
@require_torch |
|
class GenerationIntegrationTests(unittest.TestCase): |
|
@slow
def test_diverse_beam_search(self):
    """Integration test: diverse (grouped) beam search on bart-large-cnn must
    reproduce two known, mutually diverse summaries of the article."""
    article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
        The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People.
        "Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports.
        The couple announced the pregnancy in January, with an Instagram post. It is the first baby for both."""

    bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device)
    input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

    # 4 beams in 4 groups (one beam each); the diversity penalty pushes the
    # groups toward different wordings, and 2 of the 4 hypotheses are returned.
    outputs = bart_model.generate(
        input_ids,
        num_beams=4,
        num_return_sequences=2,
        num_beam_groups=4,
        diversity_penalty=2.0,
        remove_invalid_values=True,
    )

    generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)

    self.assertListEqual(
        generated_text,
        [
            "The couple announced the birth of their son, Silas Randall Timberlake, in a statement. Silas was the middle name of Timberlake's maternal grandfather Bill Bomar. Randall is the musician's own middle name, as well as his father's first. It is the first baby for both of them.",
            "Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the first child for both. The couple announced the pregnancy in January. The name Silas is the middle name of Timberlake's maternal grandfather. It's also his own middle name.",
        ],
    )
|
|
|
def test_max_length_backward_compat_greedy(self):
    """Passing `max_length` directly to `greedy_search` is deprecated and must warn."""
    article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
    tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
    model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)

    # Duplicate the single article into a batch of 2.
    encoder_input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
    encoder_input_ids = encoder_input_ids.expand(2, -1)

    model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation(encoder_input_ids, {})
    decoder_input_ids = model._prepare_decoder_input_ids_for_generation(
        encoder_input_ids,
        decoder_start_token_id=model.config.decoder_start_token_id,
        bos_token_id=model.config.bos_token_id,
    )

    with self.assertWarns(UserWarning):
        model.greedy_search(
            decoder_input_ids,
            max_length=20,
            pad_token_id=model.config.pad_token_id,
            eos_token_id=model.config.eos_token_id,
            **model_kwargs,
        )
|
|
|
def test_max_length_backward_compat_sample(self):
    """Passing `max_length` directly to `sample` is deprecated and must warn."""
    article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
    tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
    model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)

    # Duplicate the single article into a batch of 2.
    encoder_input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
    encoder_input_ids = encoder_input_ids.expand(2, -1)

    model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation(encoder_input_ids, {})
    decoder_input_ids = model._prepare_decoder_input_ids_for_generation(
        encoder_input_ids,
        decoder_start_token_id=model.config.decoder_start_token_id,
        bos_token_id=model.config.bos_token_id,
    )

    with torch.no_grad(), self.assertWarns(UserWarning):
        model.sample(
            decoder_input_ids,
            max_length=20,
            pad_token_id=model.config.pad_token_id,
            eos_token_id=model.config.eos_token_id,
            **model_kwargs,
        )
|
|
|
def test_max_length_backward_compat_beam_search(self):
    """Passing `max_length` directly to `beam_search` is deprecated and must warn."""
    article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
    tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
    model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)

    num_beams = 2

    # One batch item expanded to `num_beams` rows, as beam search expects.
    encoder_input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
    encoder_input_ids = encoder_input_ids.expand(2, -1)

    model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation(encoder_input_ids, {})
    decoder_input_ids = model._prepare_decoder_input_ids_for_generation(
        encoder_input_ids,
        decoder_start_token_id=model.config.decoder_start_token_id,
        bos_token_id=model.config.bos_token_id,
    )

    scorer = BeamSearchScorer(
        batch_size=1,
        num_beams=num_beams,
        device=torch_device,
    )
    with self.assertWarns(UserWarning):
        _ = model.beam_search(
            decoder_input_ids, num_beams=num_beams, max_length=20, beam_scorer=scorer, **model_kwargs
        )
|
|
|
def test_max_length_backward_compat_group_beam_search(self):
    """Passing `max_length` directly to `group_beam_search` is deprecated and must warn."""
    article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
    tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
    model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)

    batch_size = 1
    num_beams = 6
    num_beam_groups = 3
    num_return_sequences = num_beams * batch_size

    # One batch item expanded to `num_beams` rows, as grouped beam search expects.
    encoder_input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
    encoder_input_ids = encoder_input_ids.expand(6, -1)

    model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation(encoder_input_ids, {})
    decoder_input_ids = model._prepare_decoder_input_ids_for_generation(
        encoder_input_ids,
        decoder_start_token_id=model.config.decoder_start_token_id,
        bos_token_id=model.config.bos_token_id,
    )

    scorer = BeamSearchScorer(
        batch_size=batch_size,
        num_beams=num_beams,
        device=torch_device,
        num_beam_hyps_to_keep=num_return_sequences,
        num_beam_groups=num_beam_groups,
    )
    with self.assertWarns(UserWarning):
        model.group_beam_search(
            decoder_input_ids, scorer, num_beams=num_beams, max_length=20, **model_kwargs
        )
|
|
|
def test_max_length_warning_if_different(self):
    """Each generation method must warn when both a deprecated `max_length` and a
    `MaxLengthCriteria` stopping criterion (with a different length) are passed."""
    article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
    bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
    bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
    input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

    batch_size = 1

    # Deliberately different from `stopping_criteria_max_length` to trigger the warning.
    max_length = 20
    num_beams = 6
    num_beam_groups = 3
    num_return_sequences = num_beams * batch_size
    stopping_criteria_max_length = 18
    stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])

    # Greedy
    input_ids = input_ids.expand(6, -1)
    model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
    input_ids = bart_model._prepare_decoder_input_ids_for_generation(
        input_ids,
        decoder_start_token_id=bart_model.config.decoder_start_token_id,
        bos_token_id=bart_model.config.bos_token_id,
    )

    with self.assertWarns(UserWarning):
        bart_model.greedy_search(
            input_ids,
            max_length=max_length,
            pad_token_id=bart_model.config.pad_token_id,
            stopping_criteria=stopping_criteria,
            eos_token_id=bart_model.config.eos_token_id,
            **model_kwargs,
        )

    # Sample
    with self.assertWarns(UserWarning):
        with torch.no_grad():
            bart_model.sample(
                input_ids,
                max_length=max_length,
                stopping_criteria=stopping_criteria,
                pad_token_id=bart_model.config.pad_token_id,
                eos_token_id=bart_model.config.eos_token_id,
                **model_kwargs,
            )

    # Beam search
    beam_scorer = BeamSearchScorer(
        batch_size=batch_size,
        num_beams=num_beams,
        device=torch_device,
    )
    with self.assertWarns(UserWarning):
        with torch.no_grad():
            bart_model.beam_search(
                input_ids,
                num_beams=num_beams,
                stopping_criteria=stopping_criteria,
                max_length=max_length,
                beam_scorer=beam_scorer,
                **model_kwargs,
            )

    # Grouped beam search
    diverse_beam_scorer = BeamSearchScorer(
        batch_size=batch_size,
        num_beams=num_beams,
        device=torch_device,
        num_beam_hyps_to_keep=num_return_sequences,
        num_beam_groups=num_beam_groups,
    )
    with self.assertWarns(UserWarning):
        bart_model.group_beam_search(
            input_ids,
            diverse_beam_scorer,
            stopping_criteria=stopping_criteria,
            num_beams=num_beams,
            max_length=max_length,
            **model_kwargs,
        )
|
|
|
def test_beam_search_warning_if_max_length_is_passed(self):
    """`BeamSearchScorer(max_length=...)` is deprecated: it must warn, and the
    deprecated argument must not change the generated sequences."""
    article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
    tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
    model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)

    batch_size = 1
    num_beams = 3

    encoder_input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
    encoder_input_ids = encoder_input_ids.expand(num_beams, -1)
    model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation(encoder_input_ids, {})

    stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=18)])

    # The deprecated `max_length` argument should warn at scorer construction time.
    with self.assertWarns(UserWarning):
        deprecated_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=torch_device,
            max_length=10,
        )

    with_max_len = model.beam_search(
        encoder_input_ids,
        num_beams=num_beams,
        stopping_criteria=stopping_criteria,
        beam_scorer=deprecated_scorer,
        **model_kwargs,
    )

    plain_scorer = BeamSearchScorer(
        batch_size=batch_size,
        num_beams=num_beams,
        device=torch_device,
    )

    without_max_len = model.beam_search(
        encoder_input_ids,
        num_beams=num_beams,
        stopping_criteria=stopping_criteria,
        beam_scorer=plain_scorer,
        **model_kwargs,
    )

    # The deprecated scorer argument must have no effect on the output.
    self.assertEqual(with_max_len.tolist(), without_max_len.tolist())
|
|
|
def test_max_new_tokens(self):
    """`max_new_tokens` must bound generation relative to the prompt length, and
    combining it with `max_length` must warn."""
    article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
    bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
    bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
    input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

    # Sanity check: the tokenized article is 15 tokens long.
    self.assertEqual(list(input_ids.shape), [1, 15])

    max_new_tokens = 3
    outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)

    # 1 initial decoder token + 3 newly generated tokens.
    self.assertEqual(list(outputs.shape), [1, 4])

    # Seeding the decoder directly: 15 prompt tokens + 3 new tokens.
    outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)

    self.assertEqual(list(outputs.shape), [1, 18])

    # `max_new_tokens` together with `max_length` is ambiguous and must warn.
    with self.assertWarns(UserWarning):
        outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
|
|