import collections
import copy
import datetime
import gc
import inspect
import random
import tempfile
import unittest
import warnings

import numpy as np
import pytest
from packaging import version
from parameterized import parameterized

from transformers import AutoConfig, AutoProcessor, AutoTokenizer, is_torch_available, logging, pipeline
from transformers.testing_utils import (
    CaptureLogger,
    is_flaky,
    require_accelerate,
    require_flash_attn,
    require_optimum_quanto,
    require_read_token,
    require_torch,
    require_torch_accelerator,
    require_torch_gpu,
    require_torch_greater_or_equal,
    require_torch_multi_accelerator,
    require_torch_multi_gpu,
    require_torch_sdpa,
    set_config_for_less_flaky_test,
    set_model_for_less_flaky_test,
    set_model_tester_for_less_flaky_test,
    slow,
    torch_device,
)
from transformers.utils import is_ipex_available, is_torchdynamo_exporting

if is_torch_available():
    import torch
    import torch.nn.functional as F

    from transformers import (
        AutoModelForCausalLM,
        AutoModelForImageTextToText,
        AutoModelForSeq2SeqLM,
        AutoModelForSpeechSeq2Seq,
        AutoModelForVision2Seq,
        BartForConditionalGeneration,
        BartTokenizer,
        GPT2LMHeadModel,
        GPT2Tokenizer,
        ImageGPTForCausalImageModeling,
        SpeechEncoderDecoderModel,
        T5ForConditionalGeneration,
    )
    from transformers.cache_utils import (
        Cache,
        DynamicCache,
        EncoderDecoderCache,
        HybridCache,
        QuantoQuantizedCache,
        StaticCache,
    )
    from transformers.generation import (
        BeamSampleDecoderOnlyOutput,
        BeamSampleEncoderDecoderOutput,
        BeamSearchDecoderOnlyOutput,
        BeamSearchEncoderDecoderOutput,
        CompileConfig,
        DisjunctiveConstraint,
        GenerateBeamDecoderOnlyOutput,
        GenerateBeamEncoderDecoderOutput,
        GenerateDecoderOnlyOutput,
        GenerateEncoderDecoderOutput,
        GenerationConfig,
        GenerationMixin,
        GreedySearchDecoderOnlyOutput,
        GreedySearchEncoderDecoderOutput,
        LogitsProcessorList,
        MaxLengthCriteria,
        MinLengthLogitsProcessor,
        PhrasalConstraint,
        PromptLookupCandidateGenerator,
        SampleDecoderOnlyOutput,
        SampleEncoderDecoderOutput,
        StoppingCriteria,
        StoppingCriteriaList,
        SynthIDTextWatermarkingConfig,
        WatermarkDetector,
        WatermarkingConfig,
    )
    from transformers.generation.candidate_generator import (
        AssistedCandidateGenerator,
        AssistedCandidateGeneratorDifferentTokenizers,
    )
    from transformers.generation.utils import _speculative_sampling

from unittest.mock import patch

from transformers.utils import is_sklearn_available

VLM_CLASS_NAMES = [
    "llava",
    "idefics2",
    "idefics3",
    "mllama",
    "paligemma",
    "emu3",
    "gotocr2",
    "qwen2vl",
    "qwen2_5_vl",
    "ayavision",
    "janus",
    "gemma3",
    "mistral3",
    "chameleon",
    "internvl",
    "qwen2_5omni",
]
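
# Usage sketch (class names below are illustrative, not part of this file): a model's test class mixes
# `GenerationTesterMixin` in together with its model tester, so that every generation test here runs against a
# tiny, randomly-initialized instance of that model, e.g.
#
#     class MyModelGenerationTest(GenerationTesterMixin, unittest.TestCase):
#         all_generative_model_classes = (MyModelForCausalLM,) if is_torch_available() else ()
#
#         def setUp(self):
#             self.model_tester = MyModelTester(self)
#
# The mixin expects attributes such as `all_generative_model_classes`, `model_tester`, and `has_attentions`, plus
# the `_check_*` helpers, to be provided by the concrete test class or the shared model test suite.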

class GenerationTesterMixin:
    input_name = "input_ids"
    model_tester = None
    max_new_tokens = 3

    def prepare_config_and_inputs_for_generate(self, batch_size=2):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # Strip model inputs that `generate` either doesn't need or that the tests set themselves
        input_keys_to_ignore = [
            # we don't want to mask attention heads
            "head_mask",
            "decoder_head_mask",
            "cross_attn_head_mask",
            # we don't want encoder-decoder models to start from pre-filled decoder ids
            "decoder_input_ids",
            "decoder_attention_mask",
            # cache usage is set explicitly in each test
            "use_cache",
            # training-only inputs
            "labels",
        ]
        filtered_inputs_dict = {
            k: v[:batch_size, ...] if isinstance(v, torch.Tensor) else v
            for k, v in inputs_dict.items()
            if k not in input_keys_to_ignore
        }

        # Disable the EOS token so that generation never stops early, otherwise the fixed-length checks in the
        # tests below would fail. A pad token is kept (borrowed from the old EOS token if needed) so that batched
        # generation still works.
        text_gen_config = config.get_text_config(decoder=True)
        if text_gen_config.eos_token_id is not None and text_gen_config.pad_token_id is None:
            text_gen_config.pad_token_id = (
                text_gen_config.eos_token_id
                if isinstance(text_gen_config.eos_token_id, int)
                else text_gen_config.eos_token_id[0]
            )
        text_gen_config.eos_token_id = None
        text_gen_config.forced_eos_token_id = None

        return config, filtered_inputs_dict
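
    # For a typical decoder-only model tester the helper above returns something like (shapes illustrative):
    #     config -> text config with `eos_token_id=None`, `forced_eos_token_id=None`, and a pad token set
    #     inputs -> {"input_ids": LongTensor[2, 7], "attention_mask": LongTensor[2, 7]}
    # i.e. only the tensors `generate()` actually needs, truncated to `batch_size` rows.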

    def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5):
        """
        Checks whether a pair of generate outputs are similar. Two `generate` call outputs are considered similar in
        the following situations:
        1. The sequences are the same
        2. The sequences are different, but the scores up to (and including) the first mismatch are nearly identical
        """
        # `scores` only covers the generated tokens, not the decoder prompt tokens
        decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores)
        output_matches = output_1.sequences == output_2.sequences
        has_matching_outputs = output_matches.all()
        has_matching_scores = None
        if not has_matching_outputs:
            for batch_idx in range(output_1.sequences.shape[0]):
                batch_matches = output_matches[batch_idx]
                if batch_matches.all():
                    continue
                first_mismatch_idx = batch_matches.int().argmin()  # index of the first `False` entry
                first_mismatch_idx -= decoder_input_length  # align the sequence index with the `scores` index
                output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx]
                output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx]
                has_matching_scores = torch.allclose(
                    output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=rtol, atol=atol
                )
                if not has_matching_scores:
                    break
        self.assertTrue(has_matching_outputs or has_matching_scores)
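
    # Worked example for the helper above (illustrative numbers): with sequences [[5, 8, 2, 7]] vs [[5, 8, 2, 9]]
    # and a 1-token prompt, the first mismatch is at sequence index 3, so the scores at generation step 3 - 1 = 2
    # are compared; if they are within tolerance, the divergence is attributed to numerically tied logits and the
    # two outputs still count as similar.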

    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
        logits_processor_kwargs = {
            "bad_words_ids": [[1, 0]],
            "repetition_penalty": 1.2,
            "remove_invalid_values": True,
        }
        if do_sample:
            logits_processor_kwargs.update(
                {
                    "top_k": 10,
                    "top_p": 0.7,
                    "temperature": 0.7,
                }
            )

        # Ban special multimodal placeholder tokens (image/video/audio and their start/end markers): generating
        # them without the matching multimodal features present tends to break multimodal models. The token id may
        # live on the config or on the model tester, and it is only banned if it fits in the text vocabulary.
        if config is not None:
            for key in [
                "image_token_id",
                "video_token_id",
                "audio_token_id",
                "vision_start_token_id",
                "audio_start_token_id",
                "audio_end_token_id",
                "vision_end_token_id",
            ]:
                token_index = getattr(config, key, None)
                if token_index is None and hasattr(self, "model_tester"):
                    token_index = getattr(self.model_tester, key, None)
                if token_index is not None and token_index < config.get_text_config().vocab_size:
                    logits_processor_kwargs["bad_words_ids"].append([token_index])

        return logits_processor_kwargs
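
    # Example of the dictionary built above for `do_sample=True` and a text-only config:
    #     {"bad_words_ids": [[1, 0]], "repetition_penalty": 1.2, "remove_invalid_values": True,
    #      "top_k": 10, "top_p": 0.7, "temperature": 0.7}
    # It is unpacked into `model.generate(**logits_processor_kwargs, ...)` by the `_*_generate` helpers below.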

    def _get_beam_kwargs(self, num_return_sequences=1):
        beam_kwargs = {
            "early_stopping": False,
            "length_penalty": 2.0,
            "num_beams": 2,
            "num_return_sequences": num_return_sequences,
        }
        return beam_kwargs

    def _get_diverse_beam_kwargs(self, num_return_sequences=1):
        beam_kwargs = {
            "early_stopping": False,
            "length_penalty": 2.0,
            "num_beams": 2,
            "num_return_sequences": num_return_sequences,
            "num_beam_groups": 2,
            "diversity_penalty": 2.0,
        }
        return beam_kwargs

    def _get_constrained_beam_kwargs(self, num_return_sequences=1):
        beam_kwargs = {
            "early_stopping": False,
            "length_penalty": 2.0,
            "num_beams": num_return_sequences * 4,
            "num_return_sequences": num_return_sequences,
        }
        return beam_kwargs
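
    # These helpers only build keyword dictionaries; the `_*_generate` helpers below unpack them, roughly:
    #     model.generate(do_sample=False, max_new_tokens=self.max_new_tokens, **self._get_beam_kwargs(), ...)
    # Constrained search uses `num_beams = 4 * num_return_sequences`, presumably so there are spare beams beyond
    # the sequences that must be returned while the constraints are being satisfied.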
def _greedy_generate( |
|
|
self, |
|
|
model, |
|
|
inputs_dict, |
|
|
output_scores=False, |
|
|
output_logits=False, |
|
|
output_attentions=False, |
|
|
output_hidden_states=False, |
|
|
return_dict_in_generate=False, |
|
|
use_cache=True, |
|
|
): |
|
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) |
|
|
output_generate = model.generate( |
|
|
do_sample=False, |
|
|
num_beams=1, |
|
|
max_new_tokens=self.max_new_tokens, |
|
|
min_new_tokens=self.max_new_tokens, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
output_scores=output_scores, |
|
|
output_logits=output_logits, |
|
|
return_dict_in_generate=return_dict_in_generate, |
|
|
use_cache=use_cache, |
|
|
**logits_processor_kwargs, |
|
|
**inputs_dict, |
|
|
) |
|
|
|
|
|
return output_generate |
|
|
|
|
|
def _sample_generate( |
|
|
self, |
|
|
model, |
|
|
inputs_dict, |
|
|
num_return_sequences, |
|
|
output_scores=False, |
|
|
output_logits=False, |
|
|
output_attentions=False, |
|
|
output_hidden_states=False, |
|
|
return_dict_in_generate=False, |
|
|
use_cache=True, |
|
|
): |
|
|
torch.manual_seed(0) |
|
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config) |
|
|
output_generate = model.generate( |
|
|
do_sample=True, |
|
|
num_beams=1, |
|
|
max_new_tokens=self.max_new_tokens, |
|
|
min_new_tokens=self.max_new_tokens, |
|
|
num_return_sequences=num_return_sequences, |
|
|
output_scores=output_scores, |
|
|
output_logits=output_logits, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict_in_generate=return_dict_in_generate, |
|
|
use_cache=use_cache, |
|
|
**logits_processor_kwargs, |
|
|
**inputs_dict, |
|
|
) |
|
|
|
|
|
return output_generate |
|
|
|
|
|
def _beam_search_generate( |
|
|
self, |
|
|
model, |
|
|
inputs_dict, |
|
|
beam_kwargs, |
|
|
output_scores=False, |
|
|
output_logits=False, |
|
|
output_attentions=False, |
|
|
output_hidden_states=False, |
|
|
return_dict_in_generate=False, |
|
|
use_cache=True, |
|
|
): |
|
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) |
|
|
output_generate = model.generate( |
|
|
do_sample=False, |
|
|
max_new_tokens=self.max_new_tokens, |
|
|
min_new_tokens=self.max_new_tokens, |
|
|
output_scores=output_scores, |
|
|
output_logits=output_logits, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict_in_generate=return_dict_in_generate, |
|
|
use_cache=use_cache, |
|
|
**beam_kwargs, |
|
|
**logits_processor_kwargs, |
|
|
**inputs_dict, |
|
|
) |
|
|
|
|
|
return output_generate |
|
|
|
|
|
def _beam_sample_generate( |
|
|
self, |
|
|
model, |
|
|
inputs_dict, |
|
|
beam_kwargs, |
|
|
output_scores=False, |
|
|
output_logits=False, |
|
|
output_attentions=False, |
|
|
output_hidden_states=False, |
|
|
return_dict_in_generate=False, |
|
|
use_cache=True, |
|
|
): |
|
|
torch.manual_seed(0) |
|
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config) |
|
|
output_generate = model.generate( |
|
|
do_sample=True, |
|
|
max_new_tokens=self.max_new_tokens, |
|
|
min_new_tokens=self.max_new_tokens, |
|
|
output_scores=output_scores, |
|
|
output_logits=output_logits, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict_in_generate=return_dict_in_generate, |
|
|
use_cache=use_cache, |
|
|
**beam_kwargs, |
|
|
**logits_processor_kwargs, |
|
|
**inputs_dict, |
|
|
) |
|
|
|
|
|
return output_generate |
|
|
|
|
|
def _group_beam_search_generate( |
|
|
self, |
|
|
model, |
|
|
inputs_dict, |
|
|
beam_kwargs, |
|
|
output_scores=False, |
|
|
output_logits=False, |
|
|
output_attentions=False, |
|
|
output_hidden_states=False, |
|
|
return_dict_in_generate=False, |
|
|
use_cache=True, |
|
|
): |
|
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) |
|
|
output_generate = model.generate( |
|
|
do_sample=False, |
|
|
max_new_tokens=self.max_new_tokens, |
|
|
min_new_tokens=self.max_new_tokens, |
|
|
output_scores=output_scores, |
|
|
output_logits=output_logits, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict_in_generate=return_dict_in_generate, |
|
|
use_cache=use_cache, |
|
|
**beam_kwargs, |
|
|
**logits_processor_kwargs, |
|
|
**inputs_dict, |
|
|
) |
|
|
|
|
|
return output_generate |
|
|
|
|
|
def _constrained_beam_search_generate( |
|
|
self, |
|
|
model, |
|
|
inputs_dict, |
|
|
constraints, |
|
|
beam_kwargs, |
|
|
output_scores=False, |
|
|
output_logits=False, |
|
|
output_attentions=False, |
|
|
output_hidden_states=False, |
|
|
return_dict_in_generate=False, |
|
|
use_cache=True, |
|
|
): |
|
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) |
|
|
output_generate = model.generate( |
|
|
do_sample=False, |
|
|
max_new_tokens=self.max_new_tokens, |
|
|
min_new_tokens=self.max_new_tokens, |
|
|
output_scores=output_scores, |
|
|
output_logits=output_logits, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict_in_generate=return_dict_in_generate, |
|
|
constraints=constraints, |
|
|
use_cache=use_cache, |
|
|
**beam_kwargs, |
|
|
**logits_processor_kwargs, |
|
|
**inputs_dict, |
|
|
) |
|
|
|
|
|
return output_generate |
|
|
|
|
|
def _contrastive_generate( |
|
|
self, |
|
|
model, |
|
|
inputs_dict, |
|
|
output_scores=False, |
|
|
output_logits=False, |
|
|
output_attentions=False, |
|
|
output_hidden_states=False, |
|
|
return_dict_in_generate=False, |
|
|
use_cache=True, |
|
|
): |
|
|
contrastive_search_kwargs = { |
|
|
"penalty_alpha": 0.6, |
|
|
"top_k": 5, |
|
|
} |
|
|
|
|
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) |
|
|
output_generate = model.generate( |
|
|
do_sample=False, |
|
|
num_beams=1, |
|
|
max_new_tokens=self.max_new_tokens, |
|
|
min_new_tokens=self.max_new_tokens, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
output_scores=output_scores, |
|
|
output_logits=output_logits, |
|
|
return_dict_in_generate=return_dict_in_generate, |
|
|
use_cache=use_cache, |
|
|
**logits_processor_kwargs, |
|
|
**contrastive_search_kwargs, |
|
|
**inputs_dict, |
|
|
) |
|
|
|
|
|
return output_generate |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_greedy_generate(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
output_generate = self._greedy_generate(model=model, inputs_dict=inputs_dict) |
|
|
|
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) |
|
|
else: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) |
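                # Note on the length checks above: encoder-decoder models start generation from a single
                # `decoder_start_token_id`, so their output is `max_new_tokens + 1` tokens long, while decoder-only
                # models return the full prompt followed by `max_new_tokens` generated tokens.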
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_greedy_generate_dict_outputs(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
if self.has_attentions: |
|
|
config._attn_implementation = "eager" |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
output_generate = self._greedy_generate( |
|
|
model=model, |
|
|
inputs_dict=inputs_dict, |
|
|
output_scores=True, |
|
|
output_logits=True, |
|
|
output_hidden_states=True, |
|
|
output_attentions=self.has_attentions, |
|
|
return_dict_in_generate=True, |
|
|
use_cache=False, |
|
|
) |
|
|
|
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) |
|
|
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) |
|
|
|
|
|
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput) |
|
|
else: |
|
|
self.assertTrue( |
|
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] |
|
|
) |
|
|
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) |
|
|
|
|
|
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput) |
|
|
|
|
|
self._check_generate_outputs(output_generate, model.config) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_greedy_generate_dict_outputs_use_cache(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
if self.has_attentions: |
|
|
config._attn_implementation = "eager" |
|
|
|
|
|
if not hasattr(config.get_text_config(), "use_cache"): |
|
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching") |
|
|
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): |
|
|
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") |
|
|
|
|
|
config.is_decoder = True |
|
|
model = model_class(config).to(torch_device).eval() |
|
|
output_generate = self._greedy_generate( |
|
|
model=model, |
|
|
inputs_dict=inputs_dict, |
|
|
output_scores=True, |
|
|
output_logits=True, |
|
|
output_hidden_states=True, |
|
|
output_attentions=self.has_attentions, |
|
|
return_dict_in_generate=True, |
|
|
use_cache=True, |
|
|
) |
|
|
|
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) |
|
|
else: |
|
|
self.assertTrue( |
|
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] |
|
|
) |
|
|
|
|
|
self._check_generate_outputs(output_generate, model.config, use_cache=True) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_sample_generate(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
output_generate = self._sample_generate(model=model, inputs_dict=inputs_dict, num_return_sequences=1) |
|
|
|
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) |
|
|
else: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_sample_generate_dict_output(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
if self.has_attentions: |
|
|
config._attn_implementation = "eager" |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
output_generate = self._sample_generate( |
|
|
model=model, |
|
|
inputs_dict=inputs_dict, |
|
|
num_return_sequences=2, |
|
|
output_scores=True, |
|
|
output_logits=True, |
|
|
output_hidden_states=True, |
|
|
output_attentions=self.has_attentions, |
|
|
return_dict_in_generate=True, |
|
|
use_cache=False, |
|
|
) |
|
|
|
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) |
|
|
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) |
|
|
|
|
|
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput) |
|
|
else: |
|
|
self.assertTrue( |
|
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] |
|
|
) |
|
|
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) |
|
|
|
|
|
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput) |
|
|
|
|
|
self._check_generate_outputs(output_generate, model.config, num_return_sequences=2) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_beam_search_generate(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
|
|
|
beam_kwargs = self._get_beam_kwargs() |
|
|
output_generate = self._beam_search_generate(model=model, inputs_dict=inputs_dict, beam_kwargs=beam_kwargs) |
|
|
|
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) |
|
|
else: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_beam_search_generate_dict_output(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
if self.has_attentions: |
|
|
config._attn_implementation = "eager" |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
beam_kwargs = self._get_beam_kwargs() |
|
|
output_generate = self._beam_search_generate( |
|
|
model=model, |
|
|
inputs_dict=inputs_dict, |
|
|
beam_kwargs=beam_kwargs, |
|
|
output_scores=True, |
|
|
output_logits=True, |
|
|
output_hidden_states=True, |
|
|
output_attentions=self.has_attentions, |
|
|
return_dict_in_generate=True, |
|
|
use_cache=False, |
|
|
) |
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) |
|
|
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) |
|
|
|
|
|
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) |
|
|
else: |
|
|
self.assertTrue( |
|
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] |
|
|
) |
|
|
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) |
|
|
|
|
|
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) |
|
|
|
|
|
self._check_generate_outputs( |
|
|
output_generate, |
|
|
model.config, |
|
|
num_return_sequences=beam_kwargs["num_return_sequences"], |
|
|
num_beams=beam_kwargs["num_beams"], |
|
|
) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_beam_search_generate_dict_outputs_use_cache(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
|
|
|
if not hasattr(config.get_text_config(), "use_cache"): |
|
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching") |
|
|
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): |
|
|
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") |
|
|
|
|
|
if self.has_attentions: |
|
|
config._attn_implementation = "eager" |
|
|
model = model_class(config).to(torch_device).eval() |
|
|
beam_kwargs = self._get_beam_kwargs() |
|
|
|
|
|
config.is_decoder = True |
|
|
model = model_class(config).to(torch_device).eval() |
|
|
output_generate = self._beam_search_generate( |
|
|
model=model, |
|
|
inputs_dict=inputs_dict, |
|
|
beam_kwargs=beam_kwargs, |
|
|
output_scores=True, |
|
|
output_logits=True, |
|
|
output_hidden_states=True, |
|
|
output_attentions=self.has_attentions, |
|
|
return_dict_in_generate=True, |
|
|
use_cache=True, |
|
|
) |
|
|
|
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) |
|
|
else: |
|
|
self.assertTrue( |
|
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] |
|
|
) |
|
|
|
|
|
self._check_generate_outputs( |
|
|
output_generate, |
|
|
model.config, |
|
|
use_cache=True, |
|
|
num_return_sequences=beam_kwargs["num_return_sequences"], |
|
|
num_beams=beam_kwargs["num_beams"], |
|
|
) |
|
|
|
|
|
@require_accelerate |
|
|
@require_torch_multi_accelerator |
|
|
@pytest.mark.generate |
|
|
def test_model_parallel_beam_search(self): |
|
|
if "xpu" in torch_device: |
|
|
if not (is_ipex_available("2.5") or version.parse(torch.__version__) >= version.parse("2.6")): |
|
|
self.skipTest(reason="device_map='auto' does not work with XPU devices") |
|
|
|
|
|
for model_class in self.all_generative_model_classes: |
|
|
if model_class._no_split_modules is None: |
|
|
continue |
|
|
|
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
|
|
|
model = model_class(config).eval() |
|
|
with tempfile.TemporaryDirectory() as tmp_dir: |
|
|
model.cpu().save_pretrained(tmp_dir) |
|
|
new_model = model_class.from_pretrained(tmp_dir, device_map="auto") |
|
|
|
|
|
new_model.generate( |
|
|
max_new_tokens=self.max_new_tokens, |
|
|
num_beams=2, |
|
|
**inputs_dict, |
|
|
) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_beam_sample_generate(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
beam_kwargs = self._get_beam_kwargs() |
|
|
output_generate = self._beam_sample_generate( |
|
|
model=model, |
|
|
inputs_dict=inputs_dict, |
|
|
beam_kwargs=beam_kwargs, |
|
|
) |
|
|
|
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) |
|
|
else: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_beam_sample_generate_dict_output(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
if self.has_attentions: |
|
|
config._attn_implementation = "eager" |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
beam_kwargs = self._get_beam_kwargs() |
|
|
|
|
|
output_generate = self._beam_sample_generate( |
|
|
model=model, |
|
|
inputs_dict=inputs_dict, |
|
|
beam_kwargs=beam_kwargs, |
|
|
output_scores=True, |
|
|
output_logits=True, |
|
|
output_hidden_states=True, |
|
|
output_attentions=self.has_attentions, |
|
|
return_dict_in_generate=True, |
|
|
use_cache=False, |
|
|
) |
|
|
|
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) |
|
|
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) |
|
|
|
|
|
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput) |
|
|
else: |
|
|
self.assertTrue( |
|
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] |
|
|
) |
|
|
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) |
|
|
|
|
|
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput) |
|
|
|
|
|
self._check_generate_outputs( |
|
|
output_generate, |
|
|
model.config, |
|
|
num_return_sequences=beam_kwargs["num_return_sequences"], |
|
|
num_beams=beam_kwargs["num_beams"], |
|
|
) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_generate_without_input_ids(self): |
|
|
config, _ = self.prepare_config_and_inputs_for_generate() |
|
|
|
|
|
|
|
|
if config.bos_token_id is None: |
|
|
self.skipTest(reason="bos_token_id is None") |
|
|
|
|
|
|
|
|
if config.bos_token_id == config.pad_token_id: |
|
|
config.pad_token_id = None |
|
|
|
|
|
for model_class in self.all_generative_model_classes: |
|
|
model = model_class(config).to(torch_device) |
|
|
model.eval() |
|
|
|
|
|
output_ids_generate = model.generate( |
|
|
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True |
|
|
) |
|
|
self.assertIsNotNone(output_ids_generate) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_group_beam_search_generate(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
|
|
|
beam_kwargs = self._get_diverse_beam_kwargs() |
|
|
output_generate = self._group_beam_search_generate( |
|
|
model=model, |
|
|
inputs_dict=inputs_dict, |
|
|
beam_kwargs=beam_kwargs, |
|
|
) |
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) |
|
|
else: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) |
|
|
|
|
|
|
|
|
num_return_sequences = 2 |
|
|
beam_kwargs = self._get_diverse_beam_kwargs(num_return_sequences=num_return_sequences) |
|
|
output_generate = self._group_beam_search_generate( |
|
|
model=model, |
|
|
inputs_dict=inputs_dict, |
|
|
beam_kwargs=beam_kwargs, |
|
|
) |
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) |
|
|
else: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_group_beam_search_generate_dict_output(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
if self.has_attentions: |
|
|
config._attn_implementation = "eager" |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
beam_kwargs = self._get_diverse_beam_kwargs() |
|
|
output_generate = self._group_beam_search_generate( |
|
|
model=model, |
|
|
inputs_dict=inputs_dict, |
|
|
beam_kwargs=beam_kwargs, |
|
|
output_scores=True, |
|
|
output_logits=True, |
|
|
output_hidden_states=True, |
|
|
output_attentions=self.has_attentions, |
|
|
return_dict_in_generate=True, |
|
|
use_cache=False, |
|
|
) |
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) |
|
|
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) |
|
|
|
|
|
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) |
|
|
else: |
|
|
self.assertTrue( |
|
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] |
|
|
) |
|
|
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) |
|
|
|
|
|
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) |
|
|
|
|
|
self._check_generate_outputs( |
|
|
output_generate, |
|
|
model.config, |
|
|
num_return_sequences=beam_kwargs["num_return_sequences"], |
|
|
num_beams=beam_kwargs["num_beams"], |
|
|
) |
|
|
|
|
|
@is_flaky() |
|
|
@pytest.mark.generate |
|
|
def test_constrained_beam_search_generate(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
|
|
|
|
|
|
min_id = 3 |
|
|
max_id = config.get_text_config(decoder=True).vocab_size |
|
|
|
|
|
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] |
|
|
constraints = [ |
|
|
PhrasalConstraint(force_tokens), |
|
|
] |
|
|
|
|
|
beam_kwargs = self._get_constrained_beam_kwargs() |
|
|
output_generate = self._constrained_beam_search_generate( |
|
|
model=model, |
|
|
inputs_dict=inputs_dict, |
|
|
constraints=constraints, |
|
|
beam_kwargs=beam_kwargs, |
|
|
) |
|
|
|
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) |
|
|
else: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) |
|
|
|
|
|
for generation_output in output_generate: |
|
|
self._check_sequence_inside_sequence(force_tokens, generation_output) |
|
|
|
|
|
|
|
|
|
|
|
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] |
|
|
constraints = [ |
|
|
PhrasalConstraint(force_tokens), |
|
|
] |
|
|
|
|
|
beam_kwargs = self._get_constrained_beam_kwargs(num_return_sequences=2) |
|
|
|
|
|
output_generate = self._constrained_beam_search_generate( |
|
|
model=model, |
|
|
inputs_dict=inputs_dict, |
|
|
constraints=constraints, |
|
|
beam_kwargs=beam_kwargs, |
|
|
) |
|
|
|
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) |
|
|
else: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) |
|
|
|
|
|
for generation_output in output_generate: |
|
|
self._check_sequence_inside_sequence(force_tokens, generation_output) |
|
|
|
|
|
@is_flaky() |
|
|
@pytest.mark.generate |
|
|
def test_constrained_beam_search_generate_dict_output(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
if self.has_attentions: |
|
|
config._attn_implementation = "eager" |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
|
|
|
|
|
|
min_id = 3 |
|
|
max_id = model.config.get_text_config(decoder=True).vocab_size |
|
|
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] |
|
|
constraints = [ |
|
|
PhrasalConstraint(force_tokens), |
|
|
] |
|
|
|
|
|
beam_kwargs = self._get_constrained_beam_kwargs() |
|
|
output_generate = self._constrained_beam_search_generate( |
|
|
model=model, |
|
|
inputs_dict=inputs_dict, |
|
|
constraints=constraints, |
|
|
beam_kwargs=beam_kwargs, |
|
|
output_scores=True, |
|
|
output_logits=True, |
|
|
output_hidden_states=True, |
|
|
output_attentions=self.has_attentions, |
|
|
return_dict_in_generate=True, |
|
|
use_cache=False, |
|
|
) |
|
|
|
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) |
|
|
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) |
|
|
|
|
|
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) |
|
|
else: |
|
|
self.assertTrue( |
|
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] |
|
|
) |
|
|
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) |
|
|
|
|
|
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) |
|
|
|
|
|
self._check_generate_outputs( |
|
|
output_generate, |
|
|
model.config, |
|
|
num_return_sequences=beam_kwargs["num_return_sequences"], |
|
|
num_beams=beam_kwargs["num_beams"], |
|
|
) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_contrastive_generate(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
if model_class._is_stateful: |
|
|
self.skipTest(reason="Stateful models don't support contrastive search generation") |
|
|
|
|
|
|
|
|
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): |
|
|
self.skipTest(reason="Won't fix: old model with different cache format") |
|
|
|
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
|
|
|
|
|
|
if not hasattr(config.get_text_config(), "use_cache"): |
|
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching") |
|
|
config.is_decoder = True |
|
|
|
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
output_generate = self._contrastive_generate( |
|
|
model=model, |
|
|
inputs_dict=inputs_dict, |
|
|
use_cache=True, |
|
|
) |
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) |
|
|
else: |
|
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_contrastive_generate_dict_outputs_use_cache(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
if model_class._is_stateful: |
|
|
self.skipTest(reason="Stateful models don't support contrastive search generation") |
|
|
|
|
|
|
|
|
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): |
|
|
self.skipTest(reason="Won't fix: old model with different cache format") |
|
|
|
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
|
|
|
|
|
|
if not hasattr(config.get_text_config(), "use_cache"): |
|
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching") |
|
|
config.is_decoder = True |
|
|
if self.has_attentions: |
|
|
config._attn_implementation = "eager" |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
output_generate = self._contrastive_generate( |
|
|
model=model, |
|
|
inputs_dict=inputs_dict, |
|
|
output_scores=True, |
|
|
output_logits=True, |
|
|
output_hidden_states=True, |
|
|
output_attentions=self.has_attentions, |
|
|
return_dict_in_generate=True, |
|
|
use_cache=True, |
|
|
) |
|
|
|
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) |
|
|
else: |
|
|
self.assertTrue( |
|
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] |
|
|
) |
|
|
|
|
|
self._check_generate_outputs(output_generate, model.config, use_cache=True) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_contrastive_generate_low_memory(self): |
|
|
|
|
|
for model_class in self.all_generative_model_classes: |
|
|
if model_class._is_stateful: |
|
|
self.skipTest(reason="Stateful models don't support contrastive search generation") |
|
|
|
|
|
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]): |
|
|
self.skipTest(reason="Won't fix: old model with different cache format") |
|
|
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]): |
|
|
self.skipTest(reason="TODO: fix me") |
|
|
|
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) |
|
|
|
|
|
|
|
|
if not hasattr(config.get_text_config(), "use_cache"): |
|
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching") |
|
|
|
|
|
config.is_decoder = True |
|
|
|
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
|
|
|
low_output = model.generate( |
|
|
top_k=4, |
|
|
penalty_alpha=0.6, |
|
|
low_memory=True, |
|
|
max_new_tokens=self.max_new_tokens, |
|
|
**inputs_dict, |
|
|
use_cache=True, |
|
|
) |
|
|
|
|
|
high_output = model.generate( |
|
|
top_k=4, |
|
|
penalty_alpha=0.6, |
|
|
low_memory=False, |
|
|
max_new_tokens=self.max_new_tokens, |
|
|
**inputs_dict, |
|
|
use_cache=True, |
|
|
) |
|
|
self.assertListEqual(low_output.tolist(), high_output.tolist()) |
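            # `low_memory=True` scores the contrastive-search candidates sequentially rather than in one batched
            # forward pass, trading speed for peak memory; the result is expected to be numerically identical,
            # which is why the two token sequences are compared exactly.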
|
|
|
|
|
@parameterized.expand([("random",), ("same",)]) |
|
|
@pytest.mark.generate |
|
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for model_class in self.all_generative_model_classes: |
|
|
if model_class._is_stateful: |
|
|
self.skipTest(reason="Stateful models don't support assisted generation") |
|
|
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): |
|
|
self.skipTest(reason="Won't fix: old model with different cache format") |
|
|
if any( |
|
|
model_name in model_class.__name__.lower() |
|
|
for model_name in [ |
|
|
"bigbirdpegasus", |
|
|
"led", |
|
|
"mega", |
|
|
"moshi", |
|
|
"speech2text", |
|
|
"git", |
|
|
"prophetnet", |
|
|
"seamlessm4t", |
|
|
"clvp", |
|
|
"mllama", |
|
|
"blip2", |
|
|
"instructblip", |
|
|
"instructblipvideo", |
|
|
] |
|
|
): |
|
|
self.skipTest(reason="May fix in the future: need model-specific fixes") |
|
|
|
|
|
|
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) |
|
|
|
|
|
|
|
|
if not hasattr(config.get_text_config(), "use_cache"): |
|
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching") |
|
|
|
|
|
config.is_decoder = True |
|
|
model = model_class(config).to(torch_device).eval() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
generation_kwargs = { |
|
|
"eos_token_id": -1, |
|
|
"max_new_tokens": 4, |
|
|
"num_beams": 1, |
|
|
"do_sample": False, |
|
|
"output_scores": True, |
|
|
"output_logits": True, |
|
|
"output_hidden_states": True, |
|
|
"output_attentions": self.has_attentions, |
|
|
"return_dict_in_generate": True, |
|
|
"use_cache": True, |
|
|
} |
|
|
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config) |
|
|
|
|
|
output_greedy = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if assistant_type == "random": |
|
|
assistant_model = model_class(config).to(torch_device).eval() |
|
|
else: |
|
|
assistant_model = model |
|
|
assistant_model.generation_config.num_assistant_tokens = 2 |
|
|
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" |
|
|
generation_kwargs.update({"assistant_model": assistant_model}) |
|
|
output_assisted = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs) |
|
|
|
|
|
|
|
|
self._check_similar_generate_outputs(output_greedy, output_assisted) |
|
|
for output in (output_greedy, output_assisted): |
|
|
self._check_generate_outputs(output, model.config, use_cache=True) |
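            # Assisted (speculative) decoding only keeps draft tokens that the main model's greedy step would have
            # produced itself, so with `do_sample=False` the assisted output should match plain greedy search; any
            # divergence is tolerated only when the scores at the first mismatch are numerically tied
            # (see `_check_similar_generate_outputs`).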
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_prompt_lookup_decoding_matches_greedy_search(self): |
|
|
|
|
|
|
|
|
|
|
|
for model_class in self.all_generative_model_classes: |
|
|
if model_class._is_stateful: |
|
|
self.skipTest(reason="Stateful models don't support assisted generation") |
|
|
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): |
|
|
self.skipTest(reason="Won't fix: old model with different cache format") |
|
|
if any( |
|
|
model_name in model_class.__name__.lower() |
|
|
for model_name in [ |
|
|
"bigbirdpegasus", |
|
|
"led", |
|
|
"mega", |
|
|
"moshi", |
|
|
"speech2text", |
|
|
"git", |
|
|
"prophetnet", |
|
|
"seamlessm4t", |
|
|
"clvp", |
|
|
"fuyu", |
|
|
"mllama", |
|
|
"blip2", |
|
|
"instructblip", |
|
|
"instructblipvideo", |
|
|
*VLM_CLASS_NAMES, |
|
|
] |
|
|
): |
|
|
self.skipTest(reason="May fix in the future: need model-specific fixes") |
|
|
|
|
|
|
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) |
|
|
|
|
|
|
|
|
if not hasattr(config.get_text_config(), "use_cache"): |
|
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching") |
|
|
|
|
|
config.is_decoder = True |
|
|
model = model_class(config).to(torch_device).eval() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
generation_kwargs = { |
|
|
"eos_token_id": -1, |
|
|
"max_new_tokens": 4, |
|
|
"num_beams": 1, |
|
|
"do_sample": False, |
|
|
"output_scores": True, |
|
|
"output_logits": True, |
|
|
"output_hidden_states": True, |
|
|
"output_attentions": self.has_attentions, |
|
|
"return_dict_in_generate": True, |
|
|
"use_cache": True, |
|
|
} |
|
|
|
|
|
output_greedy = model.generate(**generation_kwargs, **inputs_dict) |
|
|
|
|
|
generation_kwargs.update({"prompt_lookup_num_tokens": 2}) |
|
|
output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict) |
|
|
|
|
|
|
|
|
self._check_similar_generate_outputs(output_greedy, output_prompt_lookup) |
|
|
for output in (output_greedy, output_prompt_lookup): |
|
|
self._check_generate_outputs(output, model.config, use_cache=True) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_dola_decoding_sample(self): |
|
|
|
|
|
for model_class in self.all_generative_model_classes: |
|
|
if model_class._is_stateful: |
|
|
self.skipTest(reason="Stateful models don't support DoLa decoding") |
|
|
|
|
|
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]): |
|
|
self.skipTest("Skip Reformer as the lm_head input size is 2 * hidden size, adopted from Rev Nets.") |
|
|
|
|
|
if any(model_name in model_class.__name__.lower() for model_name in ["marian", "mbart", "pegasus"]): |
|
|
self.skipTest("DoLa is not supported for models that don't return layerwise hidden states") |
|
|
|
|
|
if any(model_name == model_class.__name__ for model_name in ["LlavaNextVideoForConditionalGeneration"]): |
|
|
self.skipTest(f"DoLa is failing for {model_class.__name__}") |
|
|
|
|
|
|
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
|
|
|
|
|
|
if config.is_encoder_decoder: |
|
|
self.skipTest("DoLa is not supported for encoder-decoder models") |
|
|
config.is_decoder = True |
|
|
model = model_class(config).to(torch_device).eval() |
|
|
|
|
|
if model.get_output_embeddings() is None: |
|
|
self.skipTest("DoLa is not supported for models that don't have output embeddings") |
|
|
|
|
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
generation_kwargs = { |
|
|
"eos_token_id": -1, |
|
|
"max_new_tokens": 4, |
|
|
"num_beams": 1, |
|
|
"do_sample": True, |
|
|
"output_scores": True, |
|
|
"output_logits": True, |
|
|
"output_hidden_states": True, |
|
|
"output_attentions": self.has_attentions, |
|
|
"return_dict_in_generate": True, |
|
|
"use_cache": getattr(config, "use_cache", False), |
|
|
"dola_layers": "low", |
|
|
} |
|
|
output_dola = model.generate(**generation_kwargs, **logits_processor_kwargs, **inputs_dict) |
|
|
self._check_generate_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False)) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_assisted_decoding_sample(self): |
|
|
|
|
|
|
|
|
|
|
|
for model_class in self.all_generative_model_classes: |
|
|
if model_class._is_stateful: |
|
|
self.skipTest(reason="Stateful models don't support assisted generation") |
|
|
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): |
|
|
self.skipTest(reason="Won't fix: old model with different cache format") |
|
|
if any( |
|
|
model_name in model_class.__name__.lower() |
|
|
for model_name in [ |
|
|
"bigbirdpegasus", |
|
|
"led", |
|
|
"mega", |
|
|
"moshi", |
|
|
"speech2text", |
|
|
"git", |
|
|
"prophetnet", |
|
|
"seamlessm4t", |
|
|
"clvp", |
|
|
"mllama", |
|
|
"blip2", |
|
|
"instructblip", |
|
|
"instructblipvideo", |
|
|
] |
|
|
): |
|
|
self.skipTest(reason="May fix in the future: need model-specific fixes") |
|
|
|
|
|
|
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) |
|
|
|
|
|
|
|
|
if not hasattr(config.get_text_config(), "use_cache"): |
|
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching") |
|
|
|
|
|
config.is_decoder = True |
|
|
model = model_class(config).to(torch_device).eval() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assistant_model = model |
|
|
assistant_model.generation_config.num_assistant_tokens = 2 |
|
|
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" |
|
|
generation_kwargs = { |
|
|
"eos_token_id": -1, |
|
|
"max_new_tokens": 4, |
|
|
"num_beams": 1, |
|
|
"do_sample": True, |
|
|
"assistant_model": assistant_model, |
|
|
"output_scores": True, |
|
|
"output_logits": True, |
|
|
"output_hidden_states": True, |
|
|
"output_attentions": self.has_attentions, |
|
|
"return_dict_in_generate": True, |
|
|
"use_cache": True, |
|
|
} |
|
|
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config) |
|
|
output_assisted = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs) |
|
|
|
|
|
self._check_generate_outputs(output_assisted, config, use_cache=True) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_prompt_lookup_decoding_stops_at_eos(self): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
input_ids = torch.randint(1, 50, (1, 10), device=torch_device) |
|
|
arbitrary_ngram = 51 |
|
|
input_ids[:, 3] = arbitrary_ngram |
|
|
input_ids[:, -1] = arbitrary_ngram |
|
|
|
|
|
eos_token_id = torch.tensor([0], device=torch_device) |
|
|
input_ids[:, 4] = eos_token_id |
|
|
|
|
|
|
|
|
candidate_generator = PromptLookupCandidateGenerator( |
|
|
eos_token_id=eos_token_id, num_output_tokens=4, max_matching_ngram_size=1 |
|
|
) |
|
|
output_prompt_lookup = candidate_generator.get_candidates(input_ids)[0] |
|
|
|
|
|
|
|
|
self.assertTrue(output_prompt_lookup.shape[-1] == 10) |
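        # The prompt has 10 tokens and the matched ngram (`arbitrary_ngram` at positions 3 and 9) is immediately
        # followed by the EOS token at position 4, so the candidate generator proposes no continuation and returns
        # the 10 input tokens unchanged, which is what the shape check above verifies.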
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_generate_with_head_masking(self): |
|
|
"""Test designed for encoder-decoder models to ensure the attention head masking is used.""" |
|
|
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
text_config = config.get_text_config() |
|
|
if self.has_attentions: |
|
|
config._attn_implementation = "eager" |
|
|
|
|
|
|
|
|
if not text_config.is_encoder_decoder: |
|
|
continue |
|
|
model = model_class(config).to(torch_device) |
|
|
|
|
|
head_masking = { |
|
|
"head_mask": torch.zeros( |
|
|
text_config.encoder_layers, text_config.encoder_attention_heads, device=torch_device |
|
|
), |
|
|
"decoder_head_mask": torch.zeros( |
|
|
text_config.decoder_layers, text_config.decoder_attention_heads, device=torch_device |
|
|
), |
|
|
"cross_attn_head_mask": torch.zeros( |
|
|
text_config.decoder_layers, text_config.decoder_attention_heads, device=torch_device |
|
|
), |
|
|
} |
|
|
|
|
|
signature = inspect.signature(model.forward) |
|
|
|
|
|
if not set(head_masking.keys()) < {*signature.parameters.keys()}: |
|
|
continue |
|
|
|
|
|
for attn_name, (name, mask) in zip(attention_names, head_masking.items()): |
|
|
out = model.generate( |
|
|
num_beams=1, |
|
|
output_attentions=self.has_attentions, |
|
|
return_dict_in_generate=True, |
|
|
remove_invalid_values=True, |
|
|
**{name: mask}, |
|
|
**inputs_dict, |
|
|
) |
|
|
|
|
|
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1] |
|
|
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_left_padding_compatibility(self): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if len(self.all_generative_model_classes) == 0: |
|
|
self.skipTest(reason="No generative architecture available for this model.") |
|
|
|
|
|
|
|
|
if not self.has_attentions: |
|
|
self.skipTest(reason="This model doesn't support padding.") |
|
|
|
|
|
|
|
|
decoder_only_classes = [] |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, _ = self.prepare_config_and_inputs_for_generate() |
|
|
if config.is_encoder_decoder: |
|
|
continue |
|
|
else: |
|
|
decoder_only_classes.append(model_class) |
|
|
if len(decoder_only_classes) == 0: |
|
|
self.skipTest(reason="No decoder-only architecture available for this model.") |
|
|
|
|
|
|
|
|
|
|
|
has_encoder_attributes = any( |
|
|
attr_name |
|
|
for attr_name in config.to_dict().keys() |
|
|
if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size" |
|
|
) |
|
|
if has_encoder_attributes: |
|
|
self.skipTest( |
|
|
reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding." |
|
|
) |
|
|
|
|
|
|
|
|
def _prepare_model_kwargs(input_ids, attention_mask, signature): |
|
|
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask} |
|
|
if "position_ids" in signature: |
|
|
position_ids = torch.cumsum(attention_mask, dim=-1) - 1 |
|
|
position_ids.masked_fill_(attention_mask == 0, 1) |
|
|
model_kwargs["position_ids"] = position_ids |
|
|
if "cache_position" in signature: |
|
|
cache_position = torch.arange(input_ids.shape[-1], device=torch_device) |
|
|
model_kwargs["cache_position"] = cache_position |
|
|
return model_kwargs |
|
|
|
|
|
for model_class in decoder_only_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
input_ids = inputs_dict["input_ids"] |
|
|
attention_mask = inputs_dict.get("attention_mask") |
|
|
if attention_mask is None: |
|
|
attention_mask = torch.ones_like(input_ids) |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
signature = inspect.signature(model.forward).parameters.keys() |
|
|
|
|
|
|
|
|
model.generation_config.use_cache = False |
|
|
|
|
|
|
|
|
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature) |
|
|
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :] |
|
|
|
|
|
|
|
|
|
|
|
pad_token_id = ( |
|
|
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0 |
|
|
) |
|
|
pad_size = (input_ids.shape[0], 32) |
|
|
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id |
|
|
padded_input_ids = torch.cat((padding, input_ids), dim=1) |
|
|
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1) |
|
|
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature) |
|
|
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] |
|
|
|
|
|
|
|
|
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5) |
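            # With correctly built `position_ids` (and `cache_position`), the extra left-padding tokens are fully
            # masked out, so the logits for the last real token should match the unpadded run within numerical
            # tolerance; a model that derives positions from the raw input length alone would fail this check.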
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_past_key_values_format(self, custom_all_cache_shapes=None): |
|
|
""" |
|
|
Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test, or pass the |
|
|
expected cache shapes. |
|
|
Having a standard KV cache format is important for a consistent API (and for advanced generation methods). |
|
|
""" |
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs = self.model_tester.prepare_config_and_inputs_for_common() |
|
|
|
|
|
|
|
|
if not hasattr(config.get_text_config(), "use_cache"): |
|
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching") |
|
|
|
|
|
model = model_class(config).to(torch_device) |
|
|
model = model.eval() |
|
|
if "use_cache" not in inputs: |
|
|
inputs["use_cache"] = True |
|
|
outputs = model(**inputs) |
|
|
|
|
|
if "past_key_values" not in outputs: |
|
|
self.skipTest(reason="This model doesn't return `past_key_values`") |
|
|
|
|
|
|
|
|
past_kv = outputs["past_key_values"] |
|
|
is_legacy_cache = not isinstance(past_kv, Cache) |
|
|
|
|
|
text_config = config.get_text_config() |
|
|
num_decoder_layers = ( |
|
|
getattr(text_config, "decoder_layers", None) |
|
|
or getattr(text_config, "num_decoder_layers", None) |
|
|
or text_config.num_hidden_layers |
|
|
) |
|
|
|
|
|
if custom_all_cache_shapes is None: |
|
|
num_query_attention_heads = getattr( |
|
|
text_config, "decoder_attention_heads", text_config.num_attention_heads |
|
|
) |
|
|
embed_dim = getattr(text_config, "d_model", text_config.hidden_size) |
|
|
per_head_embed_dim = embed_dim // num_query_attention_heads |
|
|
num_key_value_heads = ( |
|
|
text_config.num_key_value_heads |
|
|
if getattr(text_config, "num_key_value_heads", None) is not None |
|
|
else num_query_attention_heads |
|
|
) |
|
|
if config.is_encoder_decoder: |
|
|
encoder_num_attention_heads = ( |
|
|
text_config.encoder_attention_heads |
|
|
if hasattr(text_config, "encoder_attention_heads") |
|
|
else text_config.num_attention_heads |
|
|
) |
|
|
encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads |
|
|
batch_size, seq_length = inputs["decoder_input_ids"].shape |
|
|
|
|
|
|
|
|
default_cross_attention_shape = ( |
|
|
batch_size, |
|
|
encoder_num_attention_heads, |
|
|
encoder_per_head_embed_dim, |
|
|
) |
|
|
default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim) |
|
|
all_cache_shapes = [ |
|
|
[ |
|
|
default_self_attention_shape, |
|
|
default_self_attention_shape, |
|
|
default_cross_attention_shape, |
|
|
default_cross_attention_shape, |
|
|
] |
|
|
for _ in range(num_decoder_layers) |
|
|
] |
|
|
else: |
|
|
batch_size, seq_length = inputs["input_ids"].shape |
|
|
default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim) |
|
|
all_cache_shapes = [ |
|
|
[default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers) |
|
|
] |
|
|
|
|
|
else: |
|
|
all_cache_shapes = custom_all_cache_shapes |
|
|
|
|
|
|
|
|
|
|
|
if config.is_encoder_decoder: |
|
|
num_cache_decoder_layers = ( |
|
|
len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache) |
|
|
) |
|
|
self.assertEqual(num_cache_decoder_layers, num_decoder_layers) |
|
|
|
|
|
for i in range(num_decoder_layers): |
|
|
if is_legacy_cache: |
|
|
self.assertEqual(len(past_kv[0]), 4) |
|
|
|
|
|
|
|
|
self_attention_layer_key_cache = ( |
|
|
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i] |
|
|
) |
|
|
self_attention_layer_value_cache = ( |
|
|
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i] |
|
|
) |
|
|
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) |
|
|
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) |
|
|
|
|
|
|
|
|
cross_attention_layer_key_cache = ( |
|
|
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i] |
|
|
) |
|
|
cross_attention_layer_value_cache = ( |
|
|
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i] |
|
|
) |
|
|
cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :] |
|
|
cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :] |
|
|
self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2]) |
|
|
self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3]) |
|
|
|
|
|
|
|
|
else: |
|
|
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache) |
|
|
self.assertEqual(num_cache_decoder_layers, num_decoder_layers) |
|
|
|
|
|
for i in range(num_decoder_layers): |
|
|
if is_legacy_cache: |
|
|
self.assertEqual(len(past_kv[i]), 2)
|
|
|
|
|
|
|
|
self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i] |
|
|
self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i] |
|
|
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) |
|
|
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) |
|
|
|
|
|
@pytest.mark.generate |
|
|
@parameterized.expand([("greedy", 1), ("beam search", 2)]) |
|
|
def test_generate_from_inputs_embeds(self, _, num_beams): |
|
|
"""Tests that we can generate from `inputs_embeds` instead of `input_ids` in LLMs, VLMs, etc""" |
|
|
|
|
|
|
|
|
for model_class in self.all_generative_model_classes: |
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
|
|
|
|
|
|
|
|
|
if config.is_encoder_decoder: |
|
|
continue |
|
|
config.is_decoder = True |
|
|
|
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): |
|
|
continue |
|
|
|
|
|
|
|
|
|
|
|
requires_inputs_ids = any(model_name in model_class.__name__.lower() for model_name in ["idefics"]) |
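# Some models (e.g. Idefics) must always receive `input_ids` alongside `inputs_embeds`, so the ids-free
# check at the end of this test is skipped for them.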
|
|
|
|
|
|
|
|
|
|
|
if hasattr(config, "scale_embedding"): |
|
|
config.scale_embedding = False |
|
|
|
|
|
|
|
|
|
|
|
pixel_values_is_mutually_exclusive = any( |
|
|
model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES |
|
|
) |
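# For these VLMs, pixel inputs and `inputs_embeds` are mutually exclusive, so drop the pixel inputs
# before generating from embeddings.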
|
|
if pixel_values_is_mutually_exclusive: |
|
|
inputs_dict.pop("pixel_values", None) |
|
|
inputs_dict.pop("pixel_values_videos", None) |
|
|
inputs_dict.pop("pixel_values_images", None) |
|
|
|
|
|
|
|
|
|
|
|
if "granitespeech" in model_class.__name__.lower(): |
|
|
inputs_dict.pop("input_features", None) |
|
|
|
|
|
|
|
|
has_complex_embeds_computation = any( |
|
|
model_name in model_class.__name__.lower() for model_name in ["moshi"] |
|
|
) |
|
|
|
|
|
|
|
|
missing_attention_mask = "attention_mask" not in inputs_dict |
|
|
|
|
|
|
|
|
input_ids = inputs_dict.pop("input_ids") |
|
|
generation_kwargs = { |
|
|
"return_dict_in_generate": True, |
|
|
"output_scores": True, |
|
|
"num_beams": num_beams, |
|
|
"do_sample": False, |
|
|
"max_new_tokens": 5, |
|
|
"min_new_tokens": 5, |
|
|
} |
|
|
outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict) |
|
|
self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5)) |
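# Generating from the embeddings of the same tokens (still passing `input_ids` for the prompt) should
# closely match the ids-only run.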
|
|
|
|
|
|
|
|
|
|
|
inputs_embeds = model.get_input_embeddings()(input_ids) |
|
|
outputs_from_embeds = model.generate( |
|
|
input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict |
|
|
) |
|
|
if not has_complex_embeds_computation: |
|
|
self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds) |
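# Conversely, random embeddings must change the scores; otherwise `inputs_embeds` would be ignored.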
|
|
|
|
|
|
|
|
|
|
|
random_embeds = torch.rand_like(inputs_embeds) |
|
|
outputs_from_rand_embeds = model.generate( |
|
|
input_ids, inputs_embeds=random_embeds, **generation_kwargs, **inputs_dict |
|
|
) |
|
|
for i in range(len(outputs_from_rand_embeds.scores)): |
|
|
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i])) |
|
|
|
|
|
|
|
|
|
|
|
if not (requires_inputs_ids or missing_attention_mask): |
|
|
outputs_from_embeds_wo_ids = model.generate( |
|
|
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict |
|
|
) |
|
|
outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :] |
|
|
self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_generate_from_inputs_embeds_with_static_cache(self): |
|
|
""" |
|
|
Tests that StaticCache can generate from `inputs_embeds` and that `generate()` computes `max_cache_length`
correctly. We force the model not to stop before the maximum length is reached, to verify that the cache
length is set correctly and that we never index out of bounds when slicing the cache.
|
|
""" |
|
|
for model_class in self.all_generative_model_classes: |
|
|
if not model_class._supports_static_cache: |
|
|
self.skipTest(reason="This model does not support the static cache format") |
|
|
|
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
|
|
|
if config.is_encoder_decoder: |
|
|
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache") |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): |
|
|
self.skipTest(reason="This model does not support `inputs_embeds` in generation") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pixel_values_is_mutually_exclusive = any( |
|
|
model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES |
|
|
) |
|
|
if pixel_values_is_mutually_exclusive: |
|
|
inputs_dict.pop("pixel_values", None) |
|
|
inputs_dict.pop("pixel_values_videos", None) |
|
|
inputs_dict.pop("pixel_values_images", None) |
|
|
|
|
|
input_ids = inputs_dict.pop("input_ids") |
|
|
|
|
|
model.config.use_cache = True |
|
|
model.config.is_decoder = True |
|
|
batch_size = input_ids.shape[0] |
|
|
max_new_tokens = 10 |
|
|
|
|
|
|
|
|
model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1 |
|
|
generation_kwargs = { |
|
|
"max_new_tokens": max_new_tokens, |
|
|
"cache_implementation": "static", |
|
|
"return_dict_in_generate": True, |
|
|
} |
|
|
|
|
|
text_config = model.config.get_text_config() |
|
|
head_dim = ( |
|
|
getattr(text_config, "head_dim", None) or text_config.hidden_size // text_config.num_attention_heads |
|
|
) |
|
|
num_key_value_heads = ( |
|
|
text_config.num_attention_heads |
|
|
if getattr(text_config, "num_key_value_heads", None) is None |
|
|
else text_config.num_key_value_heads |
|
|
) |
|
|
num_hidden_layers = text_config.num_hidden_layers |
|
|
|
|
|
inputs_embeds = model.get_input_embeddings()(input_ids) |
|
|
outputs = model.generate(inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict) |
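# The static cache should have been pre-allocated per layer with shape
# (batch, num_kv_heads, max_cache_len, head_dim), where `max_cache_len` is derived from the embeddings
# length plus the requested new tokens.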
|
|
|
|
|
|
|
|
|
|
|
max_length = max_new_tokens + inputs_embeds.shape[1] - 1 |
|
|
cache_shape = [batch_size, num_key_value_heads, max_length, head_dim] |
|
|
self.assertIsInstance(outputs.past_key_values, StaticCache) |
|
|
self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers) |
|
|
self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_generate_continue_from_past_key_values(self): |
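"""Tests that we can continue generation from the `past_key_values` returned by a previous `generate` call."""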
|
|
|
|
|
for model_class in self.all_generative_model_classes: |
|
|
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt", "mllama"]): |
|
|
self.skipTest(reason="Won't fix: old model with unique inputs/caches/other") |
|
|
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]): |
|
|
self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility") |
|
|
|
|
|
config, inputs = self.model_tester.prepare_config_and_inputs_for_common() |
|
|
|
|
|
if not hasattr(config.get_text_config(), "use_cache"): |
|
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if "token_type_ids" in inputs: |
|
|
del inputs["token_type_ids"] |
|
|
|
|
|
model = model_class(config).to(torch_device) |
|
|
model.eval() |
|
|
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1 |
|
|
model.generation_config.forced_eos_token_id = None |
|
|
model.generation_config.encoder_no_repeat_ngram_size = 0 |
|
|
model.generation_config.use_cache = True |
|
|
|
|
|
|
|
|
outputs = model(**inputs) |
|
|
if "past_key_values" not in outputs: |
|
|
self.skipTest(reason="This model doesn't return `past_key_values`") |
|
|
|
|
|
|
|
|
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True) |
|
|
|
|
|
|
|
|
|
|
|
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True) |
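# Continued run: 3 tokens, then 1 more token reusing the returned cache. The generated sequences are fed
# back as inputs and the attention mask is padded with ones up to the new length; the final sequences and
# caches must match the 4-token reference above.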
|
|
|
|
|
|
|
|
inputs["past_key_values"] = outputs_cached.past_key_values |
|
|
new_attention_len = outputs_cached.sequences.shape[-1] |
|
|
if config.is_encoder_decoder: |
|
|
inputs["decoder_input_ids"] = outputs_cached.sequences |
|
|
if "decoder_attention_mask" in inputs: |
|
|
inputs["decoder_attention_mask"] = torch.nn.functional.pad( |
|
|
inputs["decoder_attention_mask"], |
|
|
(0, new_attention_len - inputs["decoder_attention_mask"].shape[1]), |
|
|
mode="constant", |
|
|
value=1, |
|
|
) |
|
|
else: |
|
|
inputs["input_ids"] = outputs_cached.sequences |
|
|
if "attention_mask" in inputs: |
|
|
inputs["attention_mask"] = torch.nn.functional.pad( |
|
|
inputs["attention_mask"], |
|
|
(0, new_attention_len - inputs["attention_mask"].shape[1]), |
|
|
mode="constant", |
|
|
value=1, |
|
|
) |
|
|
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True) |
|
|
|
|
|
|
|
|
self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist()) |
|
|
for layer_idx in range(len(outputs_cached.past_key_values)): |
|
|
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])): |
|
|
self.assertTrue( |
|
|
torch.allclose( |
|
|
outputs.past_key_values[layer_idx][kv_idx], |
|
|
outputs_cached.past_key_values[layer_idx][kv_idx], |
|
|
) |
|
|
) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_generate_continue_from_inputs_embeds(self): |
|
|
"""Tests that we can continue generation from `inputs_embeds` and past key values returned from a previous `generate` call.""" |
|
|
for model_class in self.all_generative_model_classes: |
|
|
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]): |
|
|
self.skipTest(reason="Won't fix: old model with unique inputs/caches/other") |
|
|
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]): |
|
|
self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility") |
|
|
|
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
|
|
|
if "token_type_ids" in inputs_dict: |
|
|
del inputs_dict["token_type_ids"] |
|
|
|
|
|
if config.is_encoder_decoder: |
|
|
self.skipTest(reason="This model is encoder-decoder") |
|
|
if not hasattr(config, "use_cache"): |
|
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching") |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
|
|
|
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): |
|
|
self.skipTest(reason="This model does not support `inputs_embeds` in generation") |
|
|
|
|
|
|
|
|
outputs = model(**inputs_dict) |
|
|
if "past_key_values" not in outputs: |
|
|
self.skipTest(reason="This model doesn't return `past_key_values`") |
|
|
|
|
|
pixel_values_is_mutually_exclusive = any( |
|
|
model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES |
|
|
) |
|
|
if pixel_values_is_mutually_exclusive: |
|
|
inputs_dict.pop("pixel_values", None) |
|
|
inputs_dict.pop("pixel_values_videos", None) |
|
|
inputs_dict.pop("pixel_values_images", None) |
|
|
|
|
|
input_ids = inputs_dict.pop("input_ids") |
|
|
|
|
|
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1 |
|
|
model.generation_config.forced_eos_token_id = None |
|
|
model.config.is_decoder = True |
|
|
model.generation_config.use_cache = True |
|
|
|
|
|
generation_kwargs = { |
|
|
"return_dict_in_generate": True, |
|
|
"do_sample": False, |
|
|
} |
|
|
|
|
|
|
|
|
input_embeds = model.get_input_embeddings()(input_ids) |
|
|
outputs = model.generate(inputs_embeds=input_embeds, max_new_tokens=4, **generation_kwargs) |
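# Reference: 4 new tokens at once. Then 3 tokens followed by 1 more token continuing from the returned
# cache. When generating from `inputs_embeds`, `sequences` holds only the new tokens, so they are embedded
# and appended to the original embeddings before continuing.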
|
|
|
|
|
|
|
|
initial_output = model.generate(inputs_embeds=input_embeds, max_new_tokens=3, **generation_kwargs) |
|
|
continued_embeds = torch.cat([input_embeds, model.get_input_embeddings()(initial_output.sequences)], dim=1) |
|
|
cached_output = model.generate( |
|
|
inputs_embeds=continued_embeds, |
|
|
max_new_tokens=1, |
|
|
past_key_values=initial_output.past_key_values, |
|
|
**generation_kwargs, |
|
|
) |
|
|
|
|
|
|
|
|
combined_output_sequences = torch.concat([initial_output.sequences, cached_output.sequences], axis=1) |
|
|
self.assertListEqual(outputs.sequences.tolist(), combined_output_sequences.tolist()) |
|
|
|
|
|
for layer_idx in range(len(cached_output.past_key_values)): |
|
|
for kv_idx in range(len(cached_output.past_key_values[layer_idx])): |
|
|
self.assertTrue( |
|
|
torch.allclose( |
|
|
outputs.past_key_values[layer_idx][kv_idx], |
|
|
cached_output.past_key_values[layer_idx][kv_idx], |
|
|
) |
|
|
) |
|
|
|
|
|
@parameterized.expand([("offloaded",)]) |
|
|
@require_torch_accelerator |
|
|
@pytest.mark.generate |
|
|
def test_offloaded_cache_implementation(self, cache_implementation): |
|
|
"""Tests we can generate by indicating `cache_implementation` for each possible cache class""" |
|
|
for model_class in self.all_generative_model_classes: |
|
|
if not model_class._supports_cache_class: |
|
|
self.skipTest(reason="This model does not support the new cache format") |
|
|
|
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
generation_kwargs = { |
|
|
"max_new_tokens": 5, |
|
|
"use_cache": True, |
|
|
"cache_implementation": cache_implementation, |
|
|
} |
|
|
|
|
|
first_results = model.generate(**generation_kwargs, **inputs_dict)
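# Generation is deterministic here (`do_sample` defaults to False), so a second call with the same kwargs
# must produce identical sequences; the offloaded cache must not alter the outputs.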
|
|
|
|
|
|
|
|
|
|
|
|
|
|
second_results = model.generate(**generation_kwargs, **inputs_dict)
self.assertListEqual(first_results.tolist(), second_results.tolist())
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_generate_with_static_cache(self): |
|
|
""" |
|
|
Tests that generating with a static cache gives almost the same results as with a dynamic cache, and that
the output cache has the expected shapes.
|
|
""" |
|
|
set_model_tester_for_less_flaky_test(self) |
|
|
for model_class in self.all_generative_model_classes: |
|
|
if not model_class._supports_static_cache: |
|
|
self.skipTest(reason="This model does not support the static cache format") |
|
|
|
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
set_config_for_less_flaky_test(config) |
|
|
main_input = inputs_dict[model_class.main_input_name] |
|
|
|
|
|
if config.is_encoder_decoder: |
|
|
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache") |
|
|
|
|
|
config.is_decoder = True |
|
|
batch_size = main_input.shape[0] |
|
|
seq_length = self.model_tester.seq_length |
|
|
max_new_tokens = 20 |
|
|
|
|
|
for dtype in (torch.float32, torch.float16): |
|
|
model = model_class(config).to(torch_device).to(dtype).eval() |
|
|
inputs_dict = { |
|
|
k: v.to(dtype) if isinstance(v, torch.Tensor) and torch.is_floating_point(v) else v |
|
|
for k, v in inputs_dict.items() |
|
|
} |
|
|
set_model_for_less_flaky_test(model) |
|
|
|
|
|
generation_kwargs = { |
|
|
"max_new_tokens": max_new_tokens, |
|
|
"return_dict_in_generate": True, |
|
|
"output_scores": True, |
|
|
"use_cache": True, |
|
|
} |
|
|
|
|
|
static_cache_generation = model.generate( |
|
|
**generation_kwargs, **inputs_dict, cache_implementation="static" |
|
|
) |
|
|
|
|
|
|
|
|
max_cache_len = seq_length + max_new_tokens - 1 |
|
|
text_config = config.text_config if hasattr(config, "text_config") else config |
|
|
head_dim = ( |
|
|
getattr(text_config, "head_dim", None) |
|
|
or text_config.hidden_size // text_config.num_attention_heads |
|
|
) |
|
|
num_key_value_heads = ( |
|
|
text_config.num_attention_heads |
|
|
if getattr(text_config, "num_key_value_heads", None) is None |
|
|
else text_config.num_key_value_heads |
|
|
) |
|
|
num_hidden_layers = text_config.num_hidden_layers |
|
|
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim) |
|
|
self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache)) |
|
|
self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers) |
|
|
self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape) |
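# The static-cache run must closely match a regular dynamic-cache run.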
|
|
|
|
|
|
|
|
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict) |
|
|
self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation) |
|
|
|
|
|
@require_optimum_quanto |
|
|
@pytest.mark.generate |
|
|
def test_generate_with_quant_cache(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
if not model_class._supports_quantized_cache: |
|
|
self.skipTest(reason="This model does not support the quantized cache format") |
|
|
|
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
config.is_decoder = True |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
generation_kwargs = { |
|
|
"max_new_tokens": 5, |
|
|
"cache_implementation": "quantized", |
|
|
|
|
|
"cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128}, |
|
|
"return_dict_in_generate": True, |
|
|
"use_cache": True, |
|
|
} |
|
|
|
|
|
results = model.generate(**generation_kwargs, **inputs_dict) |
|
|
self.assertTrue(isinstance(results.past_key_values, QuantoQuantizedCache)) |
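# Passing a ready-made cache of a different type than the requested `cache_implementation` should raise
# an error.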
|
|
|
|
|
|
|
|
with self.assertRaises(ValueError): |
|
|
model.generate(past_key_values=DynamicCache(), **generation_kwargs, **inputs_dict)
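# An invalid `cache_config` (e.g. an unsupported number of bits) should also raise an error.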
|
|
|
|
|
|
|
|
generation_kwargs["cache_config"] = {"nbits": 60, "q_group_size": 8, "residual_length": 128} |
|
|
with self.assertRaises(ValueError): |
|
|
model.generate(**generation_kwargs, **inputs_dict) |
|
|
|
|
|
@pytest.mark.generate |
|
|
@require_torch_greater_or_equal("2.6") |
|
|
def test_generate_compile_model_forward(self): |
|
|
""" |
|
|
Tests that `.generate` is compatible with torch.compile, keeping the same results. Also confirms that |
|
|
`.forward` called from `.generate` sees no graph breaks or recompilations when compiled. |
|
|
|
|
|
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️ |
|
|
""" |
|
|
for model_class in self.all_generative_model_classes: |
|
|
|
|
|
if not model_class._supports_static_cache: |
|
|
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)") |
|
|
|
|
|
|
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=4) |
|
|
model = model_class(config).to(torch_device) |
|
|
model.eval() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if "blip" in model.__class__.__name__.lower(): |
|
|
model_to_be_compiled = model.language_model |
|
|
else: |
|
|
model_to_be_compiled = model |
|
|
|
|
|
|
|
|
main_input = inputs_dict[model.main_input_name].to(torch_device) |
|
|
half_batch_size = main_input.shape[0] // 2 |
|
|
input_1 = {} |
|
|
input_2 = {} |
|
|
for key, value in inputs_dict.items(): |
|
|
if isinstance(value, torch.Tensor): |
|
|
input_1[key] = value[:half_batch_size, :].to(torch_device) |
|
|
input_2[key] = value[half_batch_size : half_batch_size * 2, :].to(torch_device) |
|
|
else: |
|
|
input_1[key] = value |
|
|
input_2[key] = value |
|
|
model_input_sets = [input_1, input_2] |
|
|
self.assertTrue( |
|
|
model_input_sets[0][model.main_input_name].shape == model_input_sets[1][model.main_input_name].shape |
|
|
) |
|
|
|
|
|
|
|
|
torch.compiler.reset() |
|
|
has_defined_cache_implementation = model.generation_config.cache_implementation is not None |
|
|
compile_config = CompileConfig(dynamic=False) |
|
|
compile_config._compile_all_devices = True |
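# Force compilation on all device types; without this flag it may be skipped on some backends.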
|
|
|
|
|
generation_kwargs = { |
|
|
"do_sample": False, |
|
|
"max_new_tokens": 5, |
|
|
"return_dict_in_generate": True, |
|
|
"output_scores": True, |
|
|
"compile_config": compile_config, |
|
|
} |
|
|
|
|
|
|
|
|
dynamic_outputs = [] |
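# First pass: run fully eager (compilation disabled via `torch.compiler.set_stance`) to collect reference
# outputs. Without a user-defined cache implementation this path should use a non-compileable DynamicCache
# and leave the model uncompiled.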
|
|
|
|
|
|
|
|
with torch.compiler.set_stance("force_eager"): |
|
|
for model_inputs in model_input_sets: |
|
|
gen_out = model.generate(**model_inputs, **generation_kwargs) |
|
|
dynamic_outputs.append(gen_out) |
|
|
|
|
|
if not has_defined_cache_implementation: |
|
|
decoder_cache = ( |
|
|
gen_out.past_key_values.self_attention_cache |
|
|
if config.is_encoder_decoder |
|
|
else gen_out.past_key_values |
|
|
) |
|
|
self.assertTrue(isinstance(decoder_cache, DynamicCache)) |
|
|
self.assertFalse(decoder_cache.is_compileable) |
|
|
|
|
|
self.assertFalse(hasattr(model_to_be_compiled, "_compiled_call")) |
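# Second pass: request a compileable (static) cache so that `generate()` compiles the forward pass.
# Dynamo guard logs are captured below to fail the test on recompilations or guard failures.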
|
|
|
|
|
|
|
|
if not has_defined_cache_implementation: |
|
|
generation_kwargs["cache_implementation"] = "static" |
|
|
|
|
|
compiled_outputs = [] |
|
|
|
|
|
torch._logging.set_logs(recompiles_verbose=True) |
|
|
logger = logging.get_logger("torch._dynamo.guards") |
|
|
with CaptureLogger(logger) as cl: |
|
|
for model_inputs in model_input_sets: |
|
|
|
|
|
gen_out = model.generate(**model_inputs, **generation_kwargs) |
|
|
compiled_outputs.append(gen_out) |
|
|
|
|
|
decoder_cache = ( |
|
|
gen_out.past_key_values.self_attention_cache |
|
|
if config.is_encoder_decoder |
|
|
else gen_out.past_key_values |
|
|
) |
|
|
self.assertFalse(isinstance(decoder_cache, DynamicCache)) |
|
|
self.assertTrue(decoder_cache.is_compileable) |
|
|
|
|
|
self.assertTrue(hasattr(model_to_be_compiled, "_compiled_call")) |
|
|
|
|
|
if "Recompiling" in cl.out or ("guard" in cl.out and "failure" in cl.out): |
|
|
raise RuntimeError( |
|
|
f"`torch.compile` recompiled part of the forward pass in {model.__class__.__name__}. " |
|
|
"See the test logs for more details." |
|
|
) |
|
|
|
|
|
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs): |
|
|
self._check_similar_generate_outputs(dynamic_result, compiled_result) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_generate_compilation_all_outputs(self): |
|
|
""" |
|
|
Tests that all optional outputs are behaving as expected when compilation is triggered. |
|
|
In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered. |
|
|
""" |
|
|
for model_class in self.all_generative_model_classes: |
|
|
if not model_class._supports_static_cache: |
|
|
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)") |
|
|
|
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
if self.has_attentions: |
|
|
config._attn_implementation = "eager" |
|
|
model = model_class(config).to(torch_device).eval() |
|
|
|
|
|
|
|
|
torch.compiler.reset() |
|
|
has_defined_cache_implementation = model.generation_config.cache_implementation is not None |
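# Compilation is triggered through the generation config: attach a `CompileConfig` and, if no cache
# implementation is set, a compileable "static" cache. For BLIP-style composite models only the inner
# language model is compiled.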
|
|
|
|
|
|
|
|
|
|
|
|
|
|
compile_config = CompileConfig() |
|
|
compile_config._compile_all_devices = True |
|
|
if "blip" in model.__class__.__name__.lower(): |
|
|
model.language_model.generation_config.compile_config = compile_config |
|
|
if not has_defined_cache_implementation: |
|
|
model.language_model.generation_config.cache_implementation = "static" |
|
|
else: |
|
|
|
|
|
model.generation_config.compile_config = compile_config |
|
|
if not has_defined_cache_implementation: |
|
|
model.generation_config.cache_implementation = "static" |
|
|
|
|
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) |
|
|
output_generate = model.generate( |
|
|
do_sample=False, |
|
|
num_beams=1, |
|
|
max_new_tokens=self.max_new_tokens, |
|
|
min_new_tokens=self.max_new_tokens, |
|
|
output_attentions=True, |
|
|
output_hidden_states=True, |
|
|
output_scores=True, |
|
|
output_logits=True, |
|
|
return_dict_in_generate=True, |
|
|
use_cache=True, |
|
|
**logits_processor_kwargs, |
|
|
**inputs_dict, |
|
|
) |
|
|
|
|
|
if "blip" in model.__class__.__name__.lower(): |
|
|
self.assertTrue(hasattr(model.language_model, "_compiled_call")) |
|
|
else: |
|
|
self.assertTrue(hasattr(model, "_compiled_call")) |
|
|
|
|
|
if model.config.is_encoder_decoder: |
|
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) |
|
|
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) |
|
|
else: |
|
|
self.assertTrue( |
|
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] |
|
|
) |
|
|
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) |
|
|
|
|
|
self._check_generate_outputs(output_generate, model.config, use_cache=True) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_generate_methods_with_logits_to_keep(self): |
|
|
for model_class in self.all_generative_model_classes: |
|
|
if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): |
|
|
self.skipTest(reason="This model does not support `logits_to_keep` argument.") |
|
|
|
|
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
config.use_cache = True |
|
|
config.is_decoder = True |
|
|
|
|
|
model = model_class(config).to(torch_device).eval() |
|
|
|
|
|
|
|
|
|
|
|
generation_kwargs = { |
|
|
"max_new_tokens": 10, |
|
|
"do_sample": False, |
|
|
} |
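# `logits_to_keep=0` requests logits for all positions; the generated sequences must match the default
# code path.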
|
|
|
|
|
|
|
|
with_all_logits = model.generate(**generation_kwargs, **inputs_dict, logits_to_keep=0) |
|
|
|
|
|
without_all_logits = model.generate(**inputs_dict, **generation_kwargs) |
|
|
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) |
|
|
|
|
|
@pytest.mark.generate |
|
|
def test_inherits_generation_mixin(self): |
|
|
""" |
|
|
Tests that the model class directly inherits `GenerationMixin`, as opposed to relying on `PreTrainedModel` |
|
|
to inherit it. |
|
|
""" |
|
|
for model_class in self.all_generative_model_classes: |
|
|
self.assertTrue("GenerationMixin" in str(model_class.__bases__)) |
|
|
|
|
|
def _test_attention_implementation(self, attn_implementation): |
|
|
""" |
|
|
Compares the output of generate with the eager attention implementation against other implementations. |
|
|
NOTE: despite the test logic being the same, different implementations actually need different decorators, hence |
|
|
this separate function. |
|
|
""" |
|
|
max_new_tokens = 30 |
|
|
support_flag = { |
|
|
"sdpa": "_supports_sdpa", |
|
|
"flash_attention_2": "_supports_flash_attn_2", |
|
|
} |
|
|
|
|
|
for model_class in self.all_generative_model_classes: |
|
|
if not getattr(model_class, support_flag[attn_implementation]): |
|
|
self.skipTest(f"{model_class.__name__} does not support `attn_implementation={attn_implementation}`") |
|
|
|
|
|
config, original_inputs_dict = self.prepare_config_and_inputs_for_generate() |
|
|
inputs_dict = {} |
|
|
for input_name, input_data in original_inputs_dict.items(): |
|
|
if isinstance(input_data, torch.Tensor) and input_data.dtype in [torch.float32, torch.bfloat16]: |
|
|
inputs_dict[input_name] = input_data.to(torch.float16) |
|
|
else: |
|
|
inputs_dict[input_name] = input_data |
|
|
main_input = inputs_dict[model_class.main_input_name] |
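# Floating-point inputs are cast to fp16, and for flash attention the masks are replaced with all-ones
# below to avoid padding-related differences between implementations.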
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if attn_implementation == "flash_attention_2": |
|
|
for input_name in ("attention_mask", "decoder_attention_mask", "encoder_attention_mask"): |
|
|
if input_name in inputs_dict: |
|
|
inputs_dict[input_name] = torch.ones_like(inputs_dict[input_name]) |
|
|
|
|
|
|
|
|
if hasattr(config, "max_position_embeddings"): |
|
|
config.max_position_embeddings = max_new_tokens + main_input.shape[1] + 1 |
|
|
|
|
|
model = model_class(config) |
|
|
|
|
|
with tempfile.TemporaryDirectory() as tmpdirname: |
|
|
model.save_pretrained(tmpdirname) |
|
|
del model |
|
|
gc.collect() |
|
|
|
|
|
generate_kwargs = { |
|
|
"max_new_tokens": max_new_tokens, |
|
|
"do_sample": False, |
|
|
"return_dict_in_generate": True, |
|
|
"output_scores": True, |
|
|
"use_cache": True, |
|
|
} |
|
|
|
|
|
model_eager = model_class.from_pretrained( |
|
|
tmpdirname, |
|
|
torch_dtype=torch.float16, |
|
|
low_cpu_mem_usage=True, |
|
|
attn_implementation="eager", |
|
|
).to(torch_device) |
|
|
res_eager = model_eager.generate(**inputs_dict, **generate_kwargs) |
|
|
del model_eager |
|
|
gc.collect() |
|
|
|
|
|
model_attn = model_class.from_pretrained( |
|
|
tmpdirname, |
|
|
torch_dtype=torch.float16, |
|
|
low_cpu_mem_usage=True, |
|
|
attn_implementation=attn_implementation, |
|
|
).to(torch_device) |
|
|
res_attn = model_attn.generate(**inputs_dict, **generate_kwargs) |
|
|
del model_attn |
|
|
gc.collect() |
|
|
|
|
|
self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3) |
|
|
|
|
|
@pytest.mark.generate |
|
|
@require_torch_sdpa |
|
|
@slow |
|
|
def test_eager_matches_sdpa_generate(self): |
|
|
"""Tests that generate has equivalent outputs with SDPA and eager attention implementations.""" |
|
|
self._test_attention_implementation("sdpa") |
|
|
|
|
|
@pytest.mark.flash_attn_test |
|
|
@require_flash_attn |
|
|
@require_torch_gpu |
|
|
@slow |
|
|
def test_eager_matches_fa2_generate(self): |
|
|
"""Tests that generate has equivalent outputs with FA2 and eager attention implementations.""" |
|
|
self._test_attention_implementation("flash_attention_2") |
|
|
|
|
|
def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): |
|
|
input_batch_size = int(output.sequences.shape[0] / num_return_sequences) |
|
|
internal_batch_size = ( |
|
|
input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences |
|
|
) |
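# With beam search the model internally runs on `batch_size * num_beams` rows; otherwise on
# `batch_size * num_return_sequences`.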
|
|
|
|
|
prompt_length = getattr(self.model_tester, "seq_length", None) |
|
|
prompt_length = getattr(self.model_tester, "encoder_seq_length", prompt_length) |
|
|
prompt_length = getattr(self.model_tester, "text_seq_length", prompt_length) |
|
|
|
|
|
config = config.text_config if hasattr(config, "text_config") else config |
|
|
|
|
|
generated_length = ( |
|
|
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - prompt_length |
|
|
) |
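# For encoder-decoder models the output sequences start with the decoder start token, hence the `- 1`.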
|
|
decoder_past_key_values = getattr(output, "past_key_values", None) |
|
|
if config.is_encoder_decoder and isinstance(decoder_past_key_values, EncoderDecoderCache): |
|
|
decoder_past_key_values = decoder_past_key_values.self_attention_cache |
|
|
|
|
|
|
|
|
if hasattr(self.model_tester, "get_subsampled_output_lengths"): |
|
|
prompt_length = self.model_tester.get_subsampled_output_lengths(prompt_length) |
|
|
|
|
|
|
|
|
self._check_scores( |
|
|
batch_size=internal_batch_size, scores=output.scores, generated_length=generated_length, config=config |
|
|
) |
|
|
|
|
|
|
|
|
self._check_logits(batch_size=internal_batch_size, logits=output.logits, config=config) |
|
|
|
|
|
|
|
|
if self.has_attentions: |
|
|
if config.is_encoder_decoder: |
|
|
|
|
|
self._check_encoder_attention_for_generate( |
|
|
attentions=output.encoder_attentions, |
|
|
batch_size=input_batch_size, |
|
|
config=config, |
|
|
prompt_length=prompt_length, |
|
|
) |
|
|
|
|
|
self._check_attentions_for_generate( |
|
|
batch_size=internal_batch_size, |
|
|
attentions=output.decoder_attentions, |
|
|
prompt_length=1, |
|
|
output_length=output.sequences.shape[-1], |
|
|
config=config, |
|
|
decoder_past_key_values=decoder_past_key_values, |
|
|
) |
|
|
else: |
|
|
self._check_attentions_for_generate( |
|
|
batch_size=internal_batch_size, |
|
|
attentions=output.attentions, |
|
|
prompt_length=prompt_length, |
|
|
output_length=output.sequences.shape[-1], |
|
|
config=config, |
|
|
decoder_past_key_values=decoder_past_key_values, |
|
|
) |
|
|
|
|
|
|
|
|
if config.is_encoder_decoder: |
|
|
|
|
|
self._check_encoder_hidden_states_for_generate( |
|
|
hidden_states=output.encoder_hidden_states, |
|
|
batch_size=input_batch_size, |
|
|
config=config, |
|
|
prompt_length=prompt_length, |
|
|
) |
|
|
|
|
|
self._check_hidden_states_for_generate( |
|
|
batch_size=internal_batch_size, |
|
|
hidden_states=output.decoder_hidden_states, |
|
|
prompt_length=1, |
|
|
output_length=output.sequences.shape[-1], |
|
|
config=config, |
|
|
use_cache=use_cache, |
|
|
) |
|
|
else: |
|
|
self._check_hidden_states_for_generate( |
|
|
batch_size=internal_batch_size, |
|
|
hidden_states=output.hidden_states, |
|
|
prompt_length=prompt_length, |
|
|
output_length=output.sequences.shape[-1], |
|
|
config=config, |
|
|
use_cache=use_cache, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models_without_standard_cache = ( |
|
|
"bamba", |
|
|
"ctrl", |
|
|
"fsmt", |
|
|
"gptbigcode", |
|
|
"mega", |
|
|
"reformer", |
|
|
"jamba", |
|
|
"mamba", |
|
|
"xlnet", |
|
|
"zamba", |
|
|
"zamba2", |
|
|
) |
|
|
has_standard_cache = not any( |
|
|
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache |
|
|
) |
|
|
if has_standard_cache: |
|
|
if use_cache: |
|
|
cache_length = output.sequences.shape[-1] - 1 |
|
|
self._check_past_key_values_for_generate( |
|
|
batch_size=internal_batch_size, |
|
|
decoder_past_key_values=decoder_past_key_values, |
|
|
cache_length=cache_length, |
|
|
config=config, |
|
|
) |
|
|
elif use_cache is False: |
|
|
self.assertTrue(decoder_past_key_values is None) |
|
|
|
|
|
def _check_scores(self, batch_size, scores, generated_length, config): |
|
|
vocab_size = config.get_text_config(decoder=True).vocab_size |
|
|
expected_shape = (batch_size, vocab_size) |
|
|
self.assertIsInstance(scores, tuple) |
|
|
self.assertEqual(len(scores), generated_length) |
|
|
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores)) |
|
|
|
|
|
def _check_logits(self, batch_size, logits, config): |
|
|
vocab_size = config.get_text_config(decoder=True).vocab_size |
|
|
self.assertIsInstance(logits, tuple) |
|
|
self.assertListEqual([iter_logits.shape[0] for iter_logits in logits], [batch_size] * len(logits)) |
|
|
|
|
|
vocab_diff = vocab_size - logits[0].shape[-1] |
|
|
self.assertTrue(vocab_diff in [0, 1]) |
|
|
self.assertListEqual([vocab_size - score.shape[-1] for score in logits], [vocab_diff] * len(logits)) |
|
|
|
|
|
def _check_attentions_for_generate( |
|
|
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values |
|
|
): |
|
|
self.assertIsInstance(attentions, tuple) |
|
|
self.assertListEqual( |
|
|
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) |
|
|
) |
|
|
self.assertEqual(len(attentions), (output_length - prompt_length)) |
|
|
|
|
|
use_cache = decoder_past_key_values is not None |
|
|
has_static_cache = isinstance(decoder_past_key_values, (StaticCache, HybridCache)) |
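# With a cache, every step after the first feeds a single new token to the model, while the attended
# key/value length grows with the sequence (or equals the pre-allocated size for static caches).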
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for generated_length, iter_attentions in enumerate(attentions): |
|
|
|
|
|
if use_cache and generated_length > 0: |
|
|
model_input_length = 1 |
|
|
else: |
|
|
model_input_length = prompt_length + generated_length |
|
|
query_length = ( |
|
|
prompt_length + generated_length |
|
|
if not has_static_cache |
|
|
else decoder_past_key_values.get_max_cache_shape() |
|
|
) |
|
|
|
|
|
expected_shape = ( |
|
|
batch_size, |
|
|
config.num_attention_heads, |
|
|
model_input_length, |
|
|
query_length, |
|
|
) |
|
|
|
|
|
self.assertListEqual( |
|
|
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) |
|
|
) |
|
|
|
|
|
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length): |
|
|
encoder_expected_shape = (batch_size, config.num_attention_heads, prompt_length, prompt_length) |
|
|
self.assertIsInstance(attentions, tuple) |
|
|
self.assertListEqual( |
|
|
[layer_attentions.shape for layer_attentions in attentions], |
|
|
[encoder_expected_shape] * len(attentions), |
|
|
) |
|
|
|
|
|
def _check_hidden_states_for_generate( |
|
|
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False |
|
|
): |
|
|
self.assertIsInstance(hidden_states, tuple) |
|
|
self.assertListEqual( |
|
|
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], |
|
|
[True] * len(hidden_states), |
|
|
) |
|
|
self.assertEqual(len(hidden_states), (output_length - prompt_length)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for generated_length, iter_hidden_states in enumerate(hidden_states): |
|
|
|
|
|
if use_cache and generated_length > 0: |
|
|
model_input_length = 1 |
|
|
else: |
|
|
model_input_length = prompt_length + generated_length |
|
|
expected_shape = (batch_size, model_input_length, config.hidden_size) |
|
|
|
|
|
self.assertListEqual( |
|
|
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], |
|
|
[expected_shape] * len(iter_hidden_states), |
|
|
) |
|
|
|
|
|
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length): |
|
|
encoder_expected_shape = (batch_size, prompt_length, config.hidden_size) |
|
|
self.assertIsInstance(hidden_states, tuple) |
|
|
self.assertListEqual( |
|
|
[layer_hidden_states.shape for layer_hidden_states in hidden_states], |
|
|
[encoder_expected_shape] * len(hidden_states), |
|
|
) |
|
|
|
|
|
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): |
|
|
self.assertIsInstance(decoder_past_key_values, (tuple, Cache)) |
|
|
|
|
|
|
|
|
expected_shape = ( |
|
|
batch_size, |
|
|
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads, |
|
|
cache_length, |
|
|
config.hidden_size // config.num_attention_heads, |
|
|
) |
|
|
|
|
|
if isinstance(decoder_past_key_values, Cache): |
|
|
self.assertListEqual( |
|
|
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], |
|
|
[expected_shape] * len(decoder_past_key_values.key_cache), |
|
|
) |
|
|
self.assertListEqual( |
|
|
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache], |
|
|
[expected_shape] * len(decoder_past_key_values.value_cache), |
|
|
) |
|
|
|
|
|
|
|
|
else: |
|
|
self.assertListEqual( |
|
|
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values], |
|
|
[True] * len(decoder_past_key_values), |
|
|
) |
|
|
|
|
|
self.assertListEqual( |
|
|
[layer_past_key_values[0].shape for layer_past_key_values in decoder_past_key_values], |
|
|
[expected_shape] * len(decoder_past_key_values), |
|
|
) |
|
|
self.assertListEqual( |
|
|
[layer_past_key_values[1].shape for layer_past_key_values in decoder_past_key_values], |
|
|
[expected_shape] * len(decoder_past_key_values), |
|
|
) |
|
|
|
|
|
def _check_sequence_inside_sequence(self, tensor_1, tensor_2): |
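"""Checks that the shorter of the two sequences appears as a contiguous subsequence of the longer one."""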
|
|
|
|
|
|
|
|
|
|
|
if not isinstance(tensor_1, list): |
|
|
tensor_1 = tensor_1.tolist() |
|
|
if not isinstance(tensor_2, list): |
|
|
tensor_2 = tensor_2.tolist() |
|
|
|
|
|
in_order = len(tensor_1) <= len(tensor_2) |
|
|
longer = tensor_2 if in_order else tensor_1 |
|
|
shorter = tensor_1 if in_order else tensor_2 |
|
|
|
|
|
flag = False |
|
|
chunk_size = len(shorter) |
|
|
for chunk_idx in range(len(longer) - chunk_size + 1): |
|
|
subseq = longer[chunk_idx : chunk_idx + chunk_size] |
|
|
if subseq == shorter: |
|
|
flag = True |
|
|
break |
|
|
|
|
|
self.assertTrue(flag) |
|
|
|
|
|
|
|
|
@require_torch |
|
|
class UtilsFunctionsTest(unittest.TestCase): |
|
|
def test_speculative_sampling(self): |
|
|
|
|
|
candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]]) |
|
|
candidate_logits = torch.tensor( |
|
|
[ |
|
|
[ |
|
|
[-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], |
|
|
[-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0], |
|
|
[-10.0, -10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0], |
|
|
] |
|
|
] |
|
|
) |
|
|
candidate_length = 3 |
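# The candidate proposes tokens 1, 4 and 5 (argmax of each candidate logit row). The target model agrees
# on 1 and 4 but puts all of its mass on token 8 at the third position, so two tokens are accepted and the
# third is replaced by 8.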
|
|
inf = float("inf") |
|
|
new_logits = torch.tensor( |
|
|
[ |
|
|
[ |
|
|
[-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], |
|
|
[-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0], |
|
|
[-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 10.0, -inf], |
|
|
[-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], |
|
|
] |
|
|
] |
|
|
) |
|
|
last_assistant_token_is_eos = False |
|
|
validated_tokens, n_matches = _speculative_sampling( |
|
|
candidate_input_ids, |
|
|
candidate_logits, |
|
|
candidate_length, |
|
|
new_logits, |
|
|
last_assistant_token_is_eos, |
|
|
) |
|
|
self.assertTrue(n_matches.item() == 2) |
|
|
self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8]) |
|
|
|
|
|
def test_speculative_sampling_target_distribution(self): |
|
|
""" |
|
|
Asserts that the target distribution is preserved. |
|
|
Should help with catching issues like #32867. |
|
|
""" |
|
|
|
|
|
candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]]) |
|
|
candidate_logits = torch.tensor( |
|
|
[ |
|
|
[ |
|
|
[-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], |
|
|
[-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0], |
|
|
[-10.0, -10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0], |
|
|
] |
|
|
] |
|
|
) |
|
|
candidate_length = 3 |
|
|
inf = float("inf") |
|
|
new_logits = torch.tensor( |
|
|
[ |
|
|
[ |
|
|
|
|
|
[-inf, 10.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], |
|
|
|
|
|
[-inf, -inf, -inf, -inf, 10.0, -inf, -inf, -inf, -inf, -inf], |
|
|
|
|
|
[-inf, 2.0, -inf, 1.0, -inf, -inf, -inf, -0.01, 2.0, -inf], |
|
|
|
|
|
[-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], |
|
|
] |
|
|
] |
|
|
) |
|
|
last_assistant_token_is_eos = False |
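# Tokens 1 and 4 are always accepted; the third candidate (5) is always rejected and resampled from the
# target distribution, which puts most mass on tokens 1 and 8, some on 3 and very little on 7. The
# frequency checks below verify that ordering.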
|
|
last_validated_token = [] |
|
|
for _ in range(10_000): |
|
|
validated_tokens, n_matches = _speculative_sampling( |
|
|
candidate_input_ids, |
|
|
candidate_logits, |
|
|
candidate_length, |
|
|
new_logits, |
|
|
last_assistant_token_is_eos, |
|
|
) |
|
|
self.assertTrue(n_matches.item() == 2) |
|
|
self.assertTrue(validated_tokens.tolist()[0][0] == 1) |
|
|
self.assertTrue(validated_tokens.tolist()[0][1] == 4) |
|
|
self.assertTrue(validated_tokens.tolist()[0][2] in [1, 3, 7, 8]) |
|
|
last_validated_token.append(validated_tokens.tolist()[0][2]) |
|
|
|
|
|
last_token_counts = collections.Counter(last_validated_token) |
|
|
self.assertTrue(last_token_counts[1] > last_token_counts[3] > last_token_counts[7] > 0) |
|
|
self.assertTrue(last_token_counts[8] > last_token_counts[3]) |
|
|
|
|
|
def test_cache_dependant_input_preparation_exporting(self): |
|
|
self.assertFalse(is_torchdynamo_exporting())
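# Case 1: empty `input_ids` together with `inputs_embeds`.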
|
|
|
|
|
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)[:, :0] |
|
|
inputs_embeds = torch.rand((2, 8), dtype=torch.float32) |
|
|
cache_position = torch.arange(0, 8, dtype=torch.int64) |
|
|
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position) |
|
|
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting( |
|
|
input_ids, inputs_embeds, cache_position |
|
|
) |
|
|
torch.testing.assert_close(eager1, export1) |
|
|
torch.testing.assert_close(eager2, export2) |
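# Case 2: `input_ids` and `inputs_embeds` of the same length.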
|
|
|
|
|
|
|
|
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64) |
|
|
inputs_embeds = torch.rand((2, 8), dtype=torch.float32) |
|
|
cache_position = torch.arange(0, 8, dtype=torch.int64) |
|
|
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position) |
|
|
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting( |
|
|
input_ids, inputs_embeds, cache_position |
|
|
) |
|
|
torch.testing.assert_close(eager1, export1) |
|
|
torch.testing.assert_close(eager2, export2) |
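# Case 3: `input_ids` longer than `cache_position` (some tokens already cached), no `inputs_embeds`.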
|
|
|
|
|
|
|
|
input_ids = torch.randint(0, 16, (2, 12), dtype=torch.int64) |
|
|
inputs_embeds = None |
|
|
cache_position = torch.arange(0, 8, dtype=torch.int64) |
|
|
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position) |
|
|
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting( |
|
|
input_ids, inputs_embeds, cache_position |
|
|
) |
|
|
torch.testing.assert_close(eager1, export1) |
|
|
torch.testing.assert_close(eager2, export2) |
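# Case 4: `input_ids` with the same length as `cache_position`, no `inputs_embeds`.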
|
|
|
|
|
|
|
|
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64) |
|
|
inputs_embeds = None |
|
|
cache_position = torch.arange(0, 8, dtype=torch.int64) |
|
|
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position) |
|
|
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting( |
|
|
input_ids, inputs_embeds, cache_position |
|
|
) |
|
|
torch.testing.assert_close(eager1, export1) |
|
|
torch.testing.assert_close(eager2, export2) |
|
|
|
|
|
|
|
|
global_rng = random.Random() |
|
|
|
|
|
|
|
|
|
|
|
def ids_tensor(shape, vocab_size, rng=None, name=None): |
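"""Creates a random `torch.long` tensor of the given shape with values in `[0, vocab_size)`."""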
|
|
|
|
|
if rng is None: |
|
|
rng = global_rng |
|
|
|
|
|
total_dims = 1 |
|
|
for dim in shape: |
|
|
total_dims *= dim |
|
|
|
|
|
values = [] |
|
|
for _ in range(total_dims): |
|
|
values.append(rng.randint(0, vocab_size - 1)) |
|
|
|
|
|
return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous() |
|
|
|
|
|
|
|
|
|
|
|
def floats_tensor(shape, scale=1.0, rng=None, name=None): |
|
|
"""Creates a random float32 tensor""" |
|
|
if rng is None: |
|
|
rng = global_rng |
|
|
|
|
|
total_dims = 1 |
|
|
for dim in shape: |
|
|
total_dims *= dim |
|
|
|
|
|
values = [] |
|
|
for _ in range(total_dims): |
|
|
values.append(rng.random() * scale) |
|
|
|
|
|
return torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous() |
|
|
|
|
|
|
|
|
@pytest.mark.generate |
|
|
@require_torch |
|
|
class GenerationIntegrationTests(unittest.TestCase): |
|
|
@slow |
|
|
def test_diverse_beam_search(self): |
|
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood. |
|
|
The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People. |
|
|
"Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports. |
|
|
The couple announced the pregnancy in January, with an Instagram post. It is the first baby for both.""" |
|
|
|
|
|
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") |
|
|
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device) |
|
|
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) |
|
|
|
|
|
outputs = bart_model.generate( |
|
|
input_ids, |
|
|
num_beams=4, |
|
|
num_return_sequences=2, |
|
|
num_beam_groups=4, |
|
|
diversity_penalty=2.0, |
|
|
remove_invalid_values=True, |
|
|
) |
|
|
|
|
|
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
|
|
|
|
self.assertListEqual( |
|
|
generated_text, |
|
|
[ |
|
|
"The couple announced the birth of their son, Silas Randall Timberlake, in a statement. Silas was the" |
|
|
" middle name of Timberlake's maternal grandfather Bill Bomar. Randall is the musician's own middle" |
|
|
" name, as well as his father's first. It is the first baby for both of them.", |
|
|
"Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the" |
|
|
" first child for both. The couple announced the pregnancy in January. The name Silas is the middle" |
|
|
" name of Timberlake's maternal grandfather. It's also his own middle name.", |
|
|
], |
|
|
) |
|
|
|
|
|
def test_max_length_if_input_embeds(self): |
|
|
article = "Today a dragon flew over Paris." |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) |
|
|
inputs_embeds = model.get_input_embeddings()(input_ids) |
|
|
|
|
|
max_length = 20 |
|
|
input_len = input_ids.shape[-1] |
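# When generating from `inputs_embeds`, the returned sequences contain only the new tokens, so they are
# shorter than the ids-based output by exactly the prompt length.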
|
|
out_gen = model.generate(input_ids=input_ids, max_length=max_length) |
|
|
out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, max_length=max_length) |
|
|
self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1]) |
|
|
|
|
|
def test_min_length_if_input_embeds(self): |
|
|
article = "Today a dragon flew over Paris." |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) |
|
|
inputs_embeds = model.get_input_embeddings()(input_ids) |
|
|
|
|
|
min_length = 10 |
|
|
input_len = input_ids.shape[-1] |
|
|
out_gen = model.generate(input_ids=input_ids, min_length=min_length) |
|
|
out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, min_length=min_length) |
|
|
self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1]) |
|
|
|
|
|
def test_custom_stopping_criteria_overload_error(self): |
|
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" |
|
|
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") |
|
|
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) |
|
|
|
|
|
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) |
|
|
stopping_criteria = StoppingCriteriaList() |
|
|
stopping_criteria.append(MaxLengthCriteria(max_length=42)) |
|
|
with self.assertRaises(ValueError): |
|
|
bart_model.generate(input_ids, stopping_criteria=stopping_criteria) |
|
|
with self.assertRaises(ValueError): |
|
|
bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32) |
|
|
|
|
|
def test_custom_stopping_criteria(self): |
|
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" |
|
|
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") |
|
|
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) |
|
|
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) |
|
|
|
|
|
class DummyCriteria(StoppingCriteria): |
|
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: |
|
|
return input_ids.shape[-1] >= 20 |
|
|
|
|
|
stopping_criteria = StoppingCriteriaList() |
|
|
stopping_criteria.append(DummyCriteria()) |
|
|
|
|
|
self.assertEqual( |
|
|
list(bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=22).shape), |
|
|
[1, 20], |
|
|
) |
|
|
self.assertEqual( |
|
|
list(bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=18).shape), |
|
|
[1, 18], |
|
|
) |
|
|
|
|
|
|
|
|
def test_stop_sequence_stopping_criteria(self): |
|
|
prompt = """Hello I believe in""" |
|
|
generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart") |
|
|
output = generator(prompt) |
|
|
self.assertEqual( |
|
|
output, |
|
|
[{"generated_text": ("Hello I believe in we we we we we we we we we")}], |
|
|
) |
|
|
|
|
|
output = generator(prompt, stop_sequence=" we") |
|
|
self.assertEqual(output, [{"generated_text": "Hello I believe in we"}]) |
|
|
|
|
|
def test_generate_non_nlp_input_ids_as_kwarg(self): |
|
|
model = ImageGPTForCausalImageModeling.from_pretrained( |
|
|
"hf-internal-testing/tiny-random-imagegpt", max_length=10 |
|
|
).to(torch_device) |
|
|
input_ids = ids_tensor((3, 5), vocab_size=10) |
|
|
|
|
|
output_sequences_kwargs = model.generate(input_ids=input_ids).cpu() |
|
|
output_sequences = model.generate(input_ids).cpu() |
|
|
|
|
|
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist()) |
|
|
self.assertEqual(output_sequences.shape, (3, 10)) |
|
|
|
|
|
def test_generate_input_values_as_encoder_kwarg(self): |
|
|
input_values = floats_tensor((2, 250)) |
|
|
model = SpeechEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-speech-encoder-decoder") |
|
|
model = model.to(torch_device) |
|
|
output_sequences_kwargs = model.generate(input_values=input_values, max_length=5).cpu() |
|
|
output_sequences = model.generate(input_values, max_length=5).cpu() |
|
|
|
|
|
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist()) |
|
|
self.assertEqual(output_sequences.shape, (2, 5)) |
|
|
|
|
|
def test_transition_scores_group_beam_search_encoder_decoder(self): |
|
|
articles = [ |
|
|
"Justin Timberlake and Jessica Biel, welcome to parenthood.", |
|
|
"Michael Phelps is arguably the most decorated Olympian of all time.", |
|
|
] |
|
|
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") |
|
|
model = BartForConditionalGeneration.from_pretrained( |
|
|
"hf-internal-testing/tiny-random-bart", |
|
|
max_length=10, |
|
|
num_beams=2, |
|
|
num_beam_groups=2, |
|
|
num_return_sequences=2, |
|
|
diversity_penalty=1.0, |
|
|
eos_token_id=None, |
|
|
return_dict_in_generate=True, |
|
|
output_scores=True, |
|
|
length_penalty=0.0, |
|
|
) |
|
|
model = model.to(torch_device) |
|
|
|
|
|
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device) |
|
|
outputs = model.generate(input_ids=input_ids) |
|
|
|
|
|
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices) |
|
|
transition_scores_sum = transition_scores.sum(-1) |
|
|
|
|
|
torch.testing.assert_close(transition_scores_sum, outputs.sequences_scores, rtol=1e-3, atol=1e-3) |
|
|
|
|
|
@slow |
|
|
def test_green_red_watermark_generation(self): |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
tokenizer.pad_token_id = tokenizer.eos_token_id |
|
|
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device) |
|
|
input_len = model_inputs["input_ids"].shape[-1] |
|
|
|
|
|
|
|
|
watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash") |
|
|
_ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15) |
|
|
|
|
|
|
|
|
|
|
|
args = { |
|
|
"bias": 2.0, |
|
|
"context_width": 1, |
|
|
"seeding_scheme": "selfhash", |
|
|
"greenlist_ratio": 0.25, |
|
|
"hashing_key": 15485863, |
|
|
} |
|
|
output = model.generate(**model_inputs, do_sample=False, max_length=15) |
|
|
output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15) |
|
|
|
|
|
|
|
|
detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args) |
|
|
detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True) |
|
|
detection_out = detector(output[:, input_len:], return_dict=True) |
|
|
|
|
|
self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True]) |
|
|
self.assertListEqual(detection_out.prediction.tolist(), [False]) |
|
|
|
|
|
"""Check the mean bias inserted by the watermarking algorithm.""" |
|
|
|
|
|
@slow |
|
|
def test_synthid_text_watermark_generation_mean_expected_bias(self): |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
tokenizer.pad_token_id = tokenizer.eos_token_id |
|
|
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device) |
|
|
input_len = 5 |
|
|
batch_size = 200 |
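# Generate many watermarked continuations, compute the g-values on non-repeated contexts and check that
# their mean matches the processor's theoretical expectation within a small tolerance.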
|
|
|
|
|
|
|
|
watermark_config = SynthIDTextWatermarkingConfig(keys=[10, 20], ngram_len=5, debug_mode=True) |
|
|
logits_processor = watermark_config.construct_processor(model.config.vocab_size, torch_device) |
|
|
mean_g_values_repeats = [] |
|
|
for _ in range(40): |
|
|
input_ids = torch.zeros( |
|
|
(batch_size, input_len), |
|
|
dtype=torch.int64, |
|
|
device=torch_device, |
|
|
) |
|
|
model_inputs = { |
|
|
"input_ids": input_ids, |
|
|
"attention_mask": torch.ones_like(input_ids, device=torch_device), |
|
|
} |
|
|
output = model.generate( |
|
|
**model_inputs, watermarking_config=watermark_config, do_sample=True, max_length=500, top_k=1000 |
|
|
) |
|
|
g_values = logits_processor.compute_g_values(input_ids=output[:, input_len:]) |
|
|
context_repetition_mask = logits_processor.compute_context_repetition_mask( |
|
|
input_ids=output[:, input_len:], |
|
|
).unsqueeze(dim=2) |
|
|
|
|
|
mean_g_values = torch.masked.mean( |
|
|
g_values, |
|
|
mask=context_repetition_mask, |
|
|
dim=0, |
|
|
keepdim=True, |
|
|
dtype=torch.float64, |
|
|
) |
|
|
mean_g_values_repeats.append(mean_g_values) |
|
|
|
|
|
mean_g_values = torch.concat(mean_g_values_repeats, dim=0).mean(dim=0) |
|
|
expected_mean_g_value = logits_processor.expected_mean_g_value( |
|
|
vocab_size=model.config.vocab_size, |
|
|
) |
|
|
atol = 0.03 |
|
|
is_close = torch.isclose( |
|
|
mean_g_values, |
|
|
torch.tensor(expected_mean_g_value, dtype=torch.float64), |
|
|
atol=atol, |
|
|
rtol=0, |
|
|
) |
|
|
self.assertTrue(torch.all(is_close)) |
|
|
|
|
|
@slow |
|
|
def test_beam_search_example_integration(self): |
|
|
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") |
|
|
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base") |
|
|
|
|
|
encoder_input_str = "translate English to German: How old are you?" |
|
|
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids |
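# run beam search with 3 beams, starting decoding from the decoder start token and reusing precomputed encoder outputs |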
|
|
|
|
|
|
|
|
num_beams = 3 |
|
|
|
|
|
input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long) |
|
|
input_ids = input_ids * model.config.decoder_start_token_id |
|
|
|
|
|
|
|
|
model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids, return_dict=True)} |
|
|
|
|
|
outputs = model.generate( |
|
|
input_ids, num_beams=num_beams, min_length=5, eos_token_id=model.config.eos_token_id, **model_kwargs |
|
|
) |
|
|
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
|
|
|
|
self.assertListEqual(outputs, ["Wie alt bist du?"]) |
|
|
|
|
|
@slow |
|
|
def test_constrained_beam_search(self): |
|
|
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device) |
|
|
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") |
|
|
|
|
|
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids |
|
|
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids |
|
|
|
|
|
constraints = [ |
|
|
PhrasalConstraint(force_tokens), |
|
|
PhrasalConstraint(force_tokens_2), |
|
|
] |
|
|
|
|
|
starting_text = ["The soldiers were not prepared and"] |
|
|
|
|
|
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device) |
|
|
|
|
|
outputs = model.generate( |
|
|
input_ids, |
|
|
constraints=constraints, |
|
|
num_beams=10, |
|
|
num_return_sequences=1, |
|
|
no_repeat_ngram_size=1, |
|
|
max_length=30, |
|
|
remove_invalid_values=True, |
|
|
) |
|
|
|
|
|
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
|
|
|
|
self.assertListEqual( |
|
|
generated_text, |
|
|
[ |
|
|
"The soldiers were not prepared and didn't know what to do. They had no idea how they would react if" |
|
|
" the enemy attacked them, big weapons scared" |
|
|
], |
|
|
) |
|
|
|
|
|
@slow |
|
|
def test_constrained_beam_search_mixed(self): |
|
|
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device) |
|
|
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") |
|
|
|
|
|
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids |
|
|
flexible_phrases = tokenizer( |
|
|
["scream", "screams", "screaming", "screamed"], add_prefix_space=True, add_special_tokens=False |
|
|
).input_ids |
|
|
|
|
|
constraints = [ |
|
|
PhrasalConstraint(force_phrase), |
|
|
DisjunctiveConstraint(flexible_phrases), |
|
|
] |
|
|
|
|
|
starting_text = ["The soldiers", "The child"] |
|
|
|
|
|
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device) |
|
|
|
|
|
outputs = model.generate( |
|
|
input_ids, |
|
|
constraints=constraints, |
|
|
num_beams=10, |
|
|
num_return_sequences=1, |
|
|
no_repeat_ngram_size=1, |
|
|
|
|
|
remove_invalid_values=True, |
|
|
) |
|
|
|
|
|
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
|
|
|
|
self.assertListEqual( |
|
|
generated_text, |
|
|
[ |
|
|
"The soldiers, who had been stationed at the base for more than a year before being evacuated" |
|
|
" screaming scared", |
|
|
"The child was taken to a local hospital where he died.\n 'I don't think screaming scared", |
|
|
], |
|
|
) |
|
|
|
|
|
@slow |
|
|
def test_constrained_beam_search_mixed_mixin(self): |
|
|
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device) |
|
|
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") |
|
|
|
|
|
force_word = "scared" |
|
|
force_flexible = ["scream", "screams", "screaming", "screamed"] |
|
|
|
|
|
force_words_ids = [ |
|
|
tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids, |
|
|
tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids, |
|
|
] |
|
|
|
|
|
starting_text = ["The soldiers", "The child"] |
|
|
|
|
|
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device) |
|
|
|
|
|
outputs = model.generate( |
|
|
input_ids, |
|
|
force_words_ids=force_words_ids, |
|
|
num_beams=10, |
|
|
num_return_sequences=1, |
|
|
no_repeat_ngram_size=1, |
|
|
remove_invalid_values=True, |
|
|
) |
|
|
|
|
|
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
|
|
|
|
self.assertListEqual( |
|
|
generated_text, |
|
|
[ |
|
|
"The soldiers, who had been stationed at the base for more than a year before being evacuated" |
|
|
" screaming scared", |
|
|
"The child was taken to a local hospital where he died.\n 'I don't think screaming scared", |
|
|
], |
|
|
) |
|
|
|
|
|
@slow |
|
|
def test_cfg_mixin(self): |
|
|
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device) |
|
|
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") |
|
|
|
|
|
input = tokenizer(["The dragon flew over Paris,"], return_tensors="pt", return_attention_mask=True) |
|
|
input["input_ids"] = input["input_ids"].to(torch_device) |
|
|
input["attention_mask"] = input["attention_mask"].to(torch_device) |
|
|
|
|
|
outputs = model.generate(**input, max_new_tokens=32, guidance_scale=1.5) |
|
|
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
|
|
|
|
self.assertListEqual( |
|
|
generated_text, |
|
|
[ |
|
|
"The dragon flew over Paris, landing in the Rue de la Bastille. The crowd was so excited " |
|
|
'that they had to leave the city.\n\n"We\'re going to Paris!"\n' |
|
|
], |
|
|
) |
|
|
|
|
|
neg = tokenizer(["France,"], return_tensors="pt", return_attention_mask=True) |
|
|
neg["input_ids"] = neg["input_ids"].to(torch_device) |
|
|
neg["attention_mask"] = neg["attention_mask"].to(torch_device) |
|
|
outputs = model.generate( |
|
|
**input, |
|
|
max_new_tokens=32, |
|
|
guidance_scale=1.5, |
|
|
negative_prompt_ids=neg["input_ids"], |
|
|
negative_prompt_attention_mask=neg["attention_mask"], |
|
|
) |
|
|
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
|
|
|
|
self.assertListEqual( |
|
|
generated_text, |
|
|
[ |
|
|
'The dragon flew over Paris, landing on the pavement.\n\n"Paris!"\n\n"Paris!"\n\n"' |
|
|
'Paris!"\n\n"Paris!"\n\n"Paris!"\n\n' |
|
|
], |
|
|
) |
|
|
|
|
|
@slow |
|
|
def test_constrained_beam_search_example_translation_mixin(self): |
|
|
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") |
|
|
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base") |
|
|
|
|
|
encoder_input_str = "translate English to German: How old are you?" |
|
|
force_words = ["sind"] |
|
|
|
|
|
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids |
|
|
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids |
|
|
|
|
|
outputs = model.generate( |
|
|
input_ids, |
|
|
force_words_ids=force_words_ids, |
|
|
num_beams=10, |
|
|
num_return_sequences=1, |
|
|
no_repeat_ngram_size=1, |
|
|
remove_invalid_values=True, |
|
|
) |
|
|
|
|
|
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
|
|
|
|
self.assertListEqual(outputs, ["Wie alt sind Sie?"]) |
|
|
|
|
|
@slow |
|
|
def test_constrained_beam_search_example_integration(self): |
|
|
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") |
|
|
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base") |
|
|
|
|
|
encoder_input_str = "translate English to German: How old are you?" |
|
|
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids |
|
|
|
|
|
|
|
|
num_beams = 5 |
|
|
|
|
|
input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long) |
|
|
input_ids = input_ids * model.config.decoder_start_token_id |
|
|
|
|
|
|
|
|
model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids, return_dict=True)} |
|
|
|
|
|
constraint_str = "sind" |
|
|
constraint_token_ids = tokenizer.encode(constraint_str)[:-1] |
|
|
|
|
|
outputs = model.generate( |
|
|
input_ids, |
|
|
num_beams=num_beams, |
|
|
force_words_ids=[constraint_token_ids], |
|
|
min_length=5, |
|
|
eos_token_id=model.config.eos_token_id, |
|
|
**model_kwargs, |
|
|
) |
|
|
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
|
|
|
|
self.assertListEqual(outputs, ["Wie alt sind Sie?"]) |
|
|
|
|
|
@slow |
|
|
def test_per_row_stopping_criteria(self): |
|
|
text = [ |
|
|
"They completed the challenging puzzle, revealing the hidden", |
|
|
"Today a dragon flew over France", |
|
|
"The aroma of freshly baked pizza filled the kitchen", |
|
|
] |
|
|
stop_strings = ["secrets"] |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device) |
|
|
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") |
|
|
tokenizer.padding_side = "left" |
|
|
tokenizer.pad_token_id = tokenizer.eos_token_id |
|
|
input_ids = tokenizer(text, return_tensors="pt", padding="longest", add_special_tokens=False).input_ids.to( |
|
|
torch_device |
|
|
) |
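# baseline: without stop strings, every row generates up to `max_length` |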
|
|
|
|
|
|
|
|
out = model.generate(input_ids, max_length=15) |
|
|
out_text = tokenizer.batch_decode(out) |
|
|
expected_out = [ |
|
|
"They completed the challenging puzzle, revealing the hidden secrets of the world.\n", |
|
|
"<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced", |
|
|
"The aroma of freshly baked pizza filled the kitchen with a sense of freshness", |
|
|
] |
|
|
self.assertListEqual(out_text, expected_out) |
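# with a stop string, only the first row halts early and is padded with EOS up to the batch length |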
|
|
|
|
|
|
|
|
out = model.generate(input_ids, max_length=15, stop_strings=stop_strings, tokenizer=tokenizer) |
|
|
out_text = tokenizer.batch_decode(out) |
|
|
expected_out = [ |
|
|
"They completed the challenging puzzle, revealing the hidden secrets<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", |
|
|
"<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced", |
|
|
"The aroma of freshly baked pizza filled the kitchen with a sense of freshness", |
|
|
] |
|
|
self.assertListEqual(out_text, expected_out) |
|
|
|
|
|
def test_constrained_beam_search_mixin_type_checks(self): |
|
|
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random") |
|
|
model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random") |
|
|
|
|
|
encoder_input_str = "translate English to German: How old are you?" |
|
|
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids |
|
|
|
|
|
with self.assertRaises(ValueError): |
|
|
force_words = ["sind"] |
|
|
force_words_ids = tokenizer(force_words, return_tensors="pt").input_ids |
|
|
model.generate( |
|
|
input_ids, |
|
|
force_words_ids=force_words_ids, |
|
|
num_beams=10, |
|
|
num_return_sequences=1, |
|
|
no_repeat_ngram_size=1, |
|
|
remove_invalid_values=True, |
|
|
) |
|
|
|
|
|
with self.assertRaises(ValueError): |
|
|
force_words = ["sind"] |
|
|
force_words_ids = [tokenizer(force_words, return_tensors="pt").input_ids] |
|
|
model.generate( |
|
|
input_ids, |
|
|
force_words_ids=force_words_ids, |
|
|
num_beams=10, |
|
|
num_return_sequences=1, |
|
|
no_repeat_ngram_size=1, |
|
|
remove_invalid_values=True, |
|
|
) |
|
|
|
|
|
with self.assertRaises(ValueError): |
|
|
model.generate(input_ids, force_words_ids=[]) |
|
|
|
|
|
with self.assertRaises(ValueError): |
|
|
model.generate(input_ids, force_words_ids=[[-1]]) |
|
|
|
|
|
with self.assertRaises(ValueError): |
|
|
model.generate(input_ids, force_words_ids=[[[-1]]]) |
|
|
|
|
|
def test_batched_decoder_start_id(self): |
|
|
articles = [ |
|
|
"Justin Timberlake and Jessica Biel, welcome to parenthood.", |
|
|
"Michael Phelps is arguably the most decorated Olympian of all time.", |
|
|
] |
|
|
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") |
|
|
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( |
|
|
torch_device |
|
|
) |
|
|
input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device) |
|
|
decoder_start_token_id = bart_model.generation_config.decoder_start_token_id |
|
|
decoder_start_token_id_batch = [decoder_start_token_id] * input_ids.shape[0] |
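# the start token passed as an int or as a per-row list must produce identical outputs |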
|
|
|
|
|
outputs = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id) |
|
|
|
|
|
outputs_batched_ids = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id_batch) |
|
|
|
|
|
self.assertListEqual(outputs.tolist(), outputs_batched_ids.tolist()) |
|
|
|
|
|
def test_decoder_start_id_from_config(self): |
|
|
|
|
|
articles = [ |
|
|
"Justin Timberlake and Jessica Biel, welcome to parenthood.", |
|
|
"Michael Phelps is arguably the most decorated Olympian of all time.", |
|
|
] |
|
|
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") |
|
|
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( |
|
|
torch_device |
|
|
) |
|
|
input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device) |
|
|
decoder_start_token_id = bart_model.generation_config.decoder_start_token_id |
|
|
|
|
|
|
|
|
outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False)) |
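# if the model's generation config defines no start token, a user-provided `decoder_start_token_id` is used and must give the same result |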
|
|
|
|
|
|
|
|
bart_model.generation_config.decoder_start_token_id = None |
|
|
bart_model.generation_config.bos_token_id = None |
|
|
outputs_with_user_id = bart_model.generate( |
|
|
input_ids, |
|
|
generation_config=GenerationConfig(do_sample=False, decoder_start_token_id=decoder_start_token_id), |
|
|
) |
|
|
|
|
|
self.assertListEqual(outputs.tolist(), outputs_with_user_id.tolist()) |
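# with no start token defined anywhere, `generate` raises a ValueError |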
|
|
|
|
|
with self.assertRaises(ValueError): |
|
|
outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False)) |
|
|
|
|
|
def test_contrastive_search_batched(self): |
|
|
|
|
|
articles = ["Foo", "Bar Baz"] |
|
|
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") |
|
|
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device) |
|
|
|
|
|
model.config.eos_token_id = None |
|
|
input_ids_batched = tokenizer(articles, padding=True, return_tensors="pt").input_ids.to(torch_device) |
|
|
input_ids = tokenizer(articles[1], return_tensors="pt").input_ids.to(torch_device) |
|
|
|
|
|
output_sequences_batched = model.generate( |
|
|
input_ids=input_ids_batched, penalty_alpha=0.6, top_k=4, return_dict_in_generate=True, output_scores=True |
|
|
) |
|
|
output_sequences = model.generate( |
|
|
input_ids=input_ids, penalty_alpha=0.6, top_k=4, return_dict_in_generate=True, output_scores=True |
|
|
) |
|
|
|
|
|
batched_out = tokenizer.decode(output_sequences_batched.sequences[1], skip_special_tokens=True) |
|
|
out = tokenizer.decode(output_sequences.sequences[0], skip_special_tokens=True) |
|
|
self.assertEqual(batched_out, out) |
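# the first-step scores for the row present in both runs should also match, within numerical tolerance |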
|
|
|
|
|
|
|
|
max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max() |
|
|
self.assertTrue(max_score_diff < 1e-5) |
|
|
|
|
|
def test_logits_processor_not_inplace(self): |
|
|
article = "Today a dragon flew over Paris." |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) |
|
|
|
|
|
out = model.generate(input_ids, output_logits=True, output_scores=True, return_dict_in_generate=True) |
|
|
out_with_temp = model.generate( |
|
|
input_ids, |
|
|
temperature=0.5, |
|
|
do_sample=True, |
|
|
output_logits=True, |
|
|
output_scores=True, |
|
|
return_dict_in_generate=True, |
|
|
) |
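# greedy decoding applies no logits warpers, so `scores` equals the raw `logits`; with temperature sampling they must differ |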
|
|
|
|
|
|
|
|
self.assertListEqual(out.logits[-1].tolist(), out.scores[-1].tolist()) |
|
|
self.assertNotEqual(out_with_temp.logits[-1].tolist(), out_with_temp.scores[-1].tolist()) |
|
|
|
|
|
def test_eos_token_id_int_and_list_top_k_top_sampling(self): |
|
|
|
|
|
generation_kwargs = { |
|
|
"do_sample": True, |
|
|
"num_beams": 1, |
|
|
"top_p": 0.7, |
|
|
"top_k": 10, |
|
|
"temperature": 0.7, |
|
|
} |
|
|
expectation = 20 |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
text = """Hello, my dog is cute and""" |
|
|
tokens = tokenizer(text, return_tensors="pt").to(torch_device) |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) |
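# with a fixed seed, the generated length must be identical whether `eos_token_id` is an int or a list |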
|
|
|
|
|
|
|
|
|
|
|
torch.manual_seed(1) |
|
|
eos_token_id = 846 |
|
|
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) |
|
|
self.assertTrue(expectation == len(generated_tokens[0])) |
|
|
|
|
|
torch.manual_seed(1) |
|
|
eos_token_id = [846, 198] |
|
|
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) |
|
|
self.assertTrue(expectation == len(generated_tokens[0])) |
|
|
|
|
|
def test_model_kwarg_encoder_signature_filtering(self): |
|
|
|
|
|
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") |
|
|
article = """Hugging Face is a technology company based in New York and Paris.""" |
|
|
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) |
|
|
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( |
|
|
torch_device |
|
|
) |
|
|
output = bart_model.generate(input_ids).cpu().numpy() |
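# a model whose forward accepts an extra `foo` kwarg should accept `foo` in `generate` and return the same output |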
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FakeBart(BartForConditionalGeneration): |
|
|
def forward(self, input_ids, foo=None, **kwargs): |
|
|
return super().forward(input_ids, **kwargs) |
|
|
|
|
|
bart_model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device) |
|
|
fake_output = bart_model.generate(input_ids, foo="bar").cpu().numpy() |
|
|
self.assertTrue(np.array_equal(output, fake_output)) |
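# replacing the encoder with one whose forward does not accept `foo` makes `generate(..., foo=...)` raise a TypeError |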
|
|
|
|
|
|
|
|
|
|
|
class FakeEncoder(bart_model.model.encoder.__class__): |
|
|
def forward(self, input_ids, **kwargs): |
|
|
return super().forward(input_ids, **kwargs) |
|
|
|
|
|
fake_encoder = FakeEncoder(bart_model.config, bart_model.model.shared).to(torch_device) |
|
|
bart_model.model.encoder = fake_encoder |
|
|
|
|
|
|
|
|
fake_output = bart_model.generate(input_ids).cpu().numpy() |
|
|
with self.assertRaises(TypeError): |
|
|
|
|
|
bart_model.generate(input_ids, foo="bar") |
|
|
|
|
|
def test_default_max_length_warning(self): |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
model.generation_config.pad_token_id = tokenizer.eos_token_id |
|
|
|
|
|
text = "Hello world" |
|
|
tokenized_inputs = tokenizer([text], return_tensors="pt") |
|
|
input_ids = tokenized_inputs.input_ids.to(torch_device) |
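# relying on the default `max_length` warns; passing `max_length` explicitly or setting it on the generation config does not |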
|
|
|
|
|
|
|
|
with self.assertWarns(UserWarning): |
|
|
model.generate(input_ids) |
|
|
|
|
|
|
|
|
with warnings.catch_warnings(record=True) as warning_list: |
|
|
model.generate(input_ids, max_length=20) |
|
|
self.assertEqual(len(warning_list), 0) |
|
|
|
|
|
|
|
|
with warnings.catch_warnings(record=True) as warning_list: |
|
|
|
|
|
model.generation_config.max_length = 10 |
|
|
model.generate(input_ids) |
|
|
self.assertEqual(len(warning_list), 0) |
|
|
|
|
|
def test_length_warning_assisted_generation(self): |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) |
|
|
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
model.generation_config.pad_token_id = tokenizer.eos_token_id |
|
|
assistant.generation_config.pad_token_id = tokenizer.eos_token_id |
|
|
|
|
|
text = "Hello world" |
|
|
tokenized_inputs = tokenizer([text], return_tensors="pt") |
|
|
input_ids = tokenized_inputs.input_ids.to(torch_device) |
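# compatible explicit length arguments must not trigger warnings when an assistant model is used |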
|
|
|
|
|
|
|
|
with warnings.catch_warnings(record=True) as warning_list: |
|
|
model.generate( |
|
|
input_ids, |
|
|
assistant_model=assistant, |
|
|
min_new_tokens=10, |
|
|
max_length=20, |
|
|
) |
|
|
self.assertEqual(len(warning_list), 0) |
|
|
|
|
|
def test_default_assisted_generation(self): |
|
|
|
|
|
config = GenerationConfig() |
|
|
|
|
|
|
|
|
self.assertEqual(config.num_assistant_tokens, 20) |
|
|
self.assertEqual(config.num_assistant_tokens_schedule, "constant") |
|
|
self.assertEqual(config.assistant_confidence_threshold, 0.4) |
|
|
self.assertEqual(config.is_assistant, False) |
|
|
|
|
|
def test_generated_length_assisted_generation(self): |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) |
|
|
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
model.generation_config.pad_token_id = tokenizer.eos_token_id |
|
|
assistant.generation_config.pad_token_id = tokenizer.eos_token_id |
|
|
|
|
|
text = "Hello world" |
|
|
tokenized_inputs = tokenizer([text], return_tensors="pt") |
|
|
input_ids = tokenized_inputs.input_ids.to(torch_device) |
|
|
input_length = input_ids.shape[-1] |
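# the output length must respect `min_new_tokens` / `max_new_tokens`, whether one or both limits are set |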
|
|
|
|
|
out = model.generate( |
|
|
input_ids, |
|
|
assistant_model=assistant, |
|
|
min_new_tokens=10, |
|
|
max_new_tokens=20, |
|
|
) |
|
|
self.assertTrue((10 + input_length) <= out.shape[-1] <= (20 + input_length)) |
|
|
|
|
|
out = model.generate( |
|
|
input_ids, |
|
|
assistant_model=assistant, |
|
|
min_new_tokens=10, |
|
|
) |
|
|
self.assertTrue((input_length + 10) <= out.shape[-1]) |
|
|
|
|
|
out = model.generate( |
|
|
input_ids, |
|
|
assistant_model=assistant, |
|
|
max_new_tokens=7, |
|
|
) |
|
|
self.assertTrue(out.shape[-1] <= (input_length + 7)) |
|
|
|
|
|
def test_model_kwarg_assisted_decoding_decoder_only(self): |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
model.generation_config.pad_token_id = tokenizer.eos_token_id |
|
|
|
|
|
text = "Hello world" |
|
|
tokenized_inputs = tokenizer([text], return_tensors="pt") |
|
|
input_ids = tokenized_inputs.input_ids.to(torch_device) |
|
|
|
|
|
|
|
|
outputs_normal = model.generate(input_ids) |
|
|
self.assertEqual(outputs_normal.shape, (1, 20)) |
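# `token_type_ids` must be forwarded to the model and change the output |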
|
|
|
|
|
|
|
|
outputs_tti = model.generate( |
|
|
input_ids, |
|
|
token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device), |
|
|
) |
|
|
with self.assertRaises(AssertionError): |
|
|
self.assertListEqual(outputs_tti.tolist(), outputs_normal.tolist()) |
|
|
|
|
|
|
|
|
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) |
|
|
assistant.config.pad_token_id = tokenizer.eos_token_id |
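# if assisted decoding forwards model kwargs correctly, the output matches the non-assisted run with `token_type_ids` |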
|
|
|
|
|
|
|
|
outputs_assisted = model.generate( |
|
|
input_ids, |
|
|
token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device), |
|
|
assistant_model=assistant, |
|
|
) |
|
|
self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist()) |
|
|
|
|
|
def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self): |
|
|
|
|
|
|
|
|
prompt = "Alice and Bob" |
|
|
checkpoint = "EleutherAI/pythia-160m-deduped" |
|
|
tokenizer = AutoTokenizer.from_pretrained(checkpoint) |
|
|
inputs = tokenizer(prompt, return_tensors="pt") |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(checkpoint) |
|
|
|
|
|
assistant_model = model |
|
|
assistant_model.generation_config.num_assistant_tokens = 5 |
|
|
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic" |
|
|
generation_kwargs = { |
|
|
"eos_token_id": -1, |
|
|
"max_new_tokens": 5, |
|
|
"do_sample": False, |
|
|
"assistant_model": assistant_model, |
|
|
} |
|
|
model.generate(**inputs, **generation_kwargs) |
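# the 'heuristic' schedule updates `num_assistant_tokens` in place during generation |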
|
|
|
|
|
self.assertTrue(assistant_model.generation_config.num_assistant_tokens in (4, 7)) |
|
|
|
|
|
def test_assisted_decoding_num_assistant_tokens_heuristic_transient_schedule(self): |
|
|
|
|
|
|
|
|
prompt = "Alice and Bob" |
|
|
checkpoint = "EleutherAI/pythia-160m-deduped" |
|
|
tokenizer = AutoTokenizer.from_pretrained(checkpoint) |
|
|
inputs = tokenizer(prompt, return_tensors="pt") |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(checkpoint) |
|
|
|
|
|
assistant_model = model |
|
|
assistant_model.generation_config.num_assistant_tokens = 5 |
|
|
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic_transient" |
|
|
generation_kwargs = { |
|
|
"eos_token_id": -1, |
|
|
"max_new_tokens": 5, |
|
|
"do_sample": False, |
|
|
"assistant_model": assistant_model, |
|
|
} |
|
|
model.generate(**inputs, **generation_kwargs) |
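# the 'heuristic_transient' schedule must leave the assistant's generation config unchanged after generation |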
|
|
|
|
|
self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 5) |
|
|
|
|
|
@slow |
|
|
def test_validate_assistant(self): |
|
|
|
|
|
inputs = np.random.rand(160000) |
|
|
|
|
|
|
|
|
model_id = "openai/whisper-large-v2" |
|
|
processor = AutoProcessor.from_pretrained(model_id) |
|
|
model = AutoModelForSpeechSeq2Seq.from_pretrained( |
|
|
model_id, |
|
|
low_cpu_mem_usage=True, |
|
|
use_safetensors=True, |
|
|
) |
|
|
model.to(torch_device) |
|
|
|
|
|
|
|
|
features = processor(inputs, return_tensors="pt").to(torch_device) |
|
|
|
|
|
|
|
|
assistant_distil_model_id = "distil-whisper/distil-large-v2" |
|
|
assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained( |
|
|
assistant_distil_model_id, |
|
|
use_safetensors=True, |
|
|
).to(torch_device) |
|
|
self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum()) |
|
|
|
|
|
|
|
|
assistant_causal_lm = AutoModelForCausalLM.from_pretrained( |
|
|
assistant_distil_model_id, |
|
|
low_cpu_mem_usage=True, |
|
|
use_safetensors=True, |
|
|
).to(torch_device) |
|
|
self.assertTrue(model.generate(**features, assistant_model=assistant_causal_lm).sum()) |
|
|
|
|
|
|
|
|
assistant_distil_model_id = "openai/whisper-tiny" |
|
|
assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained( |
|
|
assistant_distil_model_id, |
|
|
use_safetensors=True, |
|
|
).to(torch_device) |
|
|
self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum()) |
|
|
|
|
|
|
|
|
assistant_causal_lm = AutoModelForCausalLM.from_pretrained( |
|
|
assistant_distil_model_id, |
|
|
low_cpu_mem_usage=True, |
|
|
use_safetensors=True, |
|
|
).to(torch_device) |
|
|
|
|
|
with self.assertRaises(ValueError): |
|
|
model.generate(**features, assistant_model=assistant_causal_lm) |
|
|
|
|
|
|
|
|
assistant_distil_model_id = "hf-internal-testing/tiny-random-SeamlessM4Tv2ForSpeechToText" |
|
|
assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained( |
|
|
assistant_distil_model_id, |
|
|
).to(torch_device) |
|
|
|
|
|
with self.assertRaises(ValueError): |
|
|
model.generate(**features, assistant_model=assistant_seq_to_seq) |
|
|
|
|
|
def test_compare_unprocessed_logit_scores(self): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
text = "generate yes or no: " |
|
|
input_ids = tokenizer([text], return_tensors="pt").input_ids.to(torch_device) |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
|
|
|
logits_fwd = model(input_ids).logits[:, -1, :][0] |
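# the raw logits from `generate(output_logits=True)` must match a plain forward pass on the same input |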
|
|
|
|
|
|
|
|
outputs = model.generate( |
|
|
input_ids=input_ids, |
|
|
return_dict_in_generate=True, |
|
|
output_logits=True, |
|
|
max_new_tokens=1, |
|
|
do_sample=True, |
|
|
) |
|
|
logits_gen = outputs.logits[0][0] |
|
|
|
|
|
|
|
|
self.assertListEqual(logits_fwd.tolist(), logits_gen.tolist()) |
|
|
|
|
|
def test_return_unprocessed_logit_scores(self): |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
text = "generate yes or no: " |
|
|
input_ids = tokenizer([text], return_tensors="pt").input_ids.to(torch_device) |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) |
|
|
|
|
|
outputs = model.generate( |
|
|
input_ids=input_ids, return_dict_in_generate=True, output_logits=True, max_new_tokens=3 |
|
|
) |
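# inspect the distribution over the third generated token: both 'y' and 'n' should get non-negligible probability |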
|
|
|
|
|
|
|
|
|
|
|
probs_all = torch.nn.functional.softmax(outputs.logits[2][0], dim=-1) |
|
|
indices = torch.argwhere(probs_all > 0.001) |
|
|
indices = indices[:, -1] |
|
|
tokens_max = tokenizer.batch_decode(indices, skip_special_tokens=True) |
|
|
probs_max = probs_all[probs_all > 0.001] |
|
|
|
|
|
self.assertTrue(len(indices) >= 2) |
|
|
next_token_dict = {str(t): p for t, p in zip(tokens_max, probs_max)} |
|
|
self.assertTrue("n" in next_token_dict) |
|
|
self.assertTrue("y" in next_token_dict) |
|
|
y_prob = next_token_dict["y"] |
|
|
n_prob = next_token_dict["n"] |
|
|
|
|
|
self.assertTrue(y_prob > 0.001 and n_prob > 0.001) |
|
|
self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0) |
|
|
|
|
|
@slow |
|
|
@require_torch_multi_accelerator |
|
|
def test_assisted_decoding_in_different_gpu(self): |
|
|
device_0 = f"{torch_device}:0" if torch_device != "cpu" else "cpu" |
|
|
device_1 = f"{torch_device}:1" if torch_device != "cpu" else "cpu" |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(device_0) |
|
|
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( |
|
|
device_1 |
|
|
) |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") |
|
|
model.config.pad_token_id = tokenizer.eos_token_id |
|
|
assistant.config.pad_token_id = tokenizer.eos_token_id |
|
|
|
|
|
text = "Hello world" |
|
|
tokenized_inputs = tokenizer([text], return_tensors="pt") |
|
|
input_ids = tokenized_inputs.input_ids.to(torch_device) |
|
|
input_length = input_ids.shape[-1] |
|
|
|
|
|
out = model.generate( |
|
|
input_ids, |
|
|
assistant_model=assistant, |
|
|
max_new_tokens=20, |
|
|
) |
|
|
self.assertTrue(input_length <= out.shape[-1] <= input_length + 20) |
|
|
|
|
|
@slow |
|
|
@require_torch_accelerator |
|
|
def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self): |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( |
|
|
torch_device |
|
|
) |
|
|
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( |
|
|
"cpu" |
|
|
) |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") |
|
|
model.config.pad_token_id = tokenizer.eos_token_id |
|
|
assistant.config.pad_token_id = tokenizer.eos_token_id |
|
|
|
|
|
text = "Hello world" |
|
|
tokenized_inputs = tokenizer([text], return_tensors="pt") |
|
|
input_ids = tokenized_inputs.input_ids.to(torch_device) |
|
|
input_length = input_ids.shape[-1] |
|
|
|
|
|
out = model.generate( |
|
|
input_ids, |
|
|
assistant_model=assistant, |
|
|
max_new_tokens=20, |
|
|
) |
|
|
self.assertTrue(input_length <= out.shape[-1] <= input_length + 20) |
|
|
|
|
|
def test_special_tokens_fall_back_to_model_default(self): |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( |
|
|
torch_device |
|
|
) |
|
|
test_bos_id = 50 |
|
|
|
|
|
|
|
|
gen_output = model.generate() |
|
|
self.assertTrue(model.generation_config.bos_token_id is not None) |
|
|
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0]) |
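# a `bos_token_id` set on a user-provided generation config takes precedence over the model's default |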
|
|
|
|
|
|
|
|
generation_config = GenerationConfig(bos_token_id=test_bos_id) |
|
|
gen_output = model.generate(generation_config=generation_config) |
|
|
self.assertFalse(model.generation_config.bos_token_id == gen_output[0, 0]) |
|
|
self.assertTrue(generation_config.bos_token_id == gen_output[0, 0]) |
|
|
self.assertTrue(test_bos_id == gen_output[0, 0]) |
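# if the user config leaves `bos_token_id` unset, `generate` falls back to the model's default |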
|
|
|
|
|
|
|
|
|
|
|
generation_config = GenerationConfig(bos_token_id=None) |
|
|
gen_output = model.generate(generation_config=generation_config) |
|
|
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0]) |
|
|
self.assertFalse(test_bos_id == gen_output[0, 0]) |
|
|
self.assertTrue(generation_config.bos_token_id is None) |
|
|
|
|
|
|
|
|
model.generation_config.bos_token_id = test_bos_id |
|
|
gen_output = model.generate(generation_config=generation_config) |
|
|
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0]) |
|
|
self.assertTrue(test_bos_id == gen_output[0, 0]) |
|
|
self.assertTrue(generation_config.bos_token_id is None) |
|
|
|
|
|
def test_speculative_decoding_equals_regular_decoding(self): |
|
|
draft_name = "double7/vicuna-68m" |
|
|
target_name = "Qwen/Qwen2-0.5B-Instruct" |
|
|
|
|
|
draft_model = AutoModelForCausalLM.from_pretrained(draft_name) |
|
|
target_model = AutoModelForCausalLM.from_pretrained(target_name) |
|
|
|
|
|
assistant_tokenizer = AutoTokenizer.from_pretrained(draft_name) |
|
|
target_tokenizer = AutoTokenizer.from_pretrained(target_name) |
|
|
|
|
|
prompt_size = torch.randint(low=20, high=100, size=(1,)) |
|
|
max_new_tokens = torch.randint(low=10, high=50, size=(1,)) |
|
|
input_ids = (torch.rand(1, prompt_size[0]) * 100).to(int) + 50 |
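# greedy decoding with and without the (different-tokenizer) assistant must produce identical outputs |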
|
|
|
|
|
max_new_tokens_item = max_new_tokens[0].item() |
|
|
expected_out = target_model.generate(input_ids, do_sample=False, max_new_tokens=max_new_tokens_item) |
|
|
predicted_out = target_model.generate( |
|
|
input_ids, |
|
|
do_sample=False, |
|
|
max_new_tokens=max_new_tokens_item, |
|
|
assistant_model=draft_model, |
|
|
tokenizer=target_tokenizer, |
|
|
assistant_tokenizer=assistant_tokenizer, |
|
|
) |
|
|
|
|
|
self.assertEqual(expected_out.shape, predicted_out.shape) |
|
|
self.assertTrue((expected_out == predicted_out).all().item()) |
|
|
|
|
|
@pytest.mark.generate |
|
|
@require_torch_multi_gpu |
|
|
def test_generate_with_static_cache_multi_gpu(self): |
|
|
""" |
|
|
Tests that the static cache is set correctly and that `generate` works correctly when using multiple GPUs. |
|
|
""" |
|
|
|
|
|
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0} |
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map |
|
|
) |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") |
|
|
|
|
|
text = "Hello world" |
|
|
tokenized_inputs = tokenizer([text], return_tensors="pt") |
|
|
input_ids = tokenized_inputs.input_ids.to(torch_device) |
|
|
|
|
|
generation_kwargs = { |
|
|
"max_new_tokens": 20, |
|
|
"cache_implementation": "static", |
|
|
"return_dict_in_generate": True, |
|
|
} |
|
|
|
|
|
results = model.generate(input_ids, **generation_kwargs) |
|
|
self.assertTrue(isinstance(results.past_key_values, StaticCache)) |
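# each cache layer must live on the same device as the corresponding model layer from the device_map |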
|
|
|
|
|
|
|
|
key_cache_0 = results.past_key_values.key_cache[0] |
|
|
value_cache_0 = results.past_key_values.value_cache[0] |
|
|
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) |
|
|
|
|
|
key_cache_1 = results.past_key_values.key_cache[1] |
|
|
value_cache_1 = results.past_key_values.value_cache[1] |
|
|
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) |
|
|
|
|
|
@pytest.mark.generate |
|
|
@require_torch_multi_gpu |
|
|
def test_init_static_cache_multi_gpu(self): |
|
|
""" |
|
|
Tests that the static cache is set correctly when it is initialized manually in a multi-GPU setup. |
|
|
""" |
|
|
|
|
|
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0} |
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map |
|
|
) |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") |
|
|
|
|
|
text = "Hello world" |
|
|
tokenized_inputs = tokenizer([text], return_tensors="pt") |
|
|
input_ids = tokenized_inputs.input_ids.to(torch_device) |
|
|
|
|
|
generation_kwargs = { |
|
|
"max_new_tokens": 20, |
|
|
"return_dict_in_generate": True, |
|
|
} |
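# when the static cache is created manually, pass a `layer_device_map` mirroring the model's device_map so each cache layer lands on the right device |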
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_device_map = {0: 0, 1: 1} |
|
|
past_key_values = StaticCache( |
|
|
config=model.config, |
|
|
max_batch_size=1, |
|
|
max_cache_len=30, |
|
|
device=torch_device, |
|
|
dtype=model.dtype, |
|
|
layer_device_map=layer_device_map, |
|
|
) |
|
|
results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs) |
|
|
|
|
|
|
|
|
key_cache_0 = results.past_key_values.key_cache[0] |
|
|
value_cache_0 = results.past_key_values.value_cache[0] |
|
|
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) |
|
|
|
|
|
key_cache_1 = results.past_key_values.key_cache[1] |
|
|
value_cache_1 = results.past_key_values.value_cache[1] |
|
|
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) |
|
|
|
|
|
@slow |
|
|
def test_padding_input_contrastive_search_gpt2(self): |
|
|
|
|
|
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2") |
|
|
model.to(torch_device) |
|
|
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", clean_up_tokenization_spaces=True) |
|
|
|
|
|
|
|
|
tokenizer.padding_side = "left" |
|
|
|
|
|
|
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
model.generation_config.pad_token_id = model.generation_config.eos_token_id |
|
|
|
|
|
|
|
|
prompt_text = "The whispered legends of the haunted mansion spoke" |
|
|
|
|
|
|
|
|
encoded_prompt = tokenizer(prompt_text, return_tensors="pt", padding=True) |
|
|
input_ids = encoded_prompt.input_ids.to(torch_device) |
|
|
attention_mask = encoded_prompt.attention_mask.to(torch_device) |
|
|
|
|
|
|
|
|
penalty_alpha = 0.6 |
|
|
top_k = 4 |
|
|
|
|
|
|
|
|
padding_length = 10 |
|
|
|
|
|
|
|
|
outputs = model.generate( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
do_sample=False, |
|
|
penalty_alpha=penalty_alpha, |
|
|
top_k=top_k, |
|
|
max_new_tokens=64, |
|
|
) |
|
|
generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True) |
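# left-pad the inputs by `padding_length` tokens and generate again; the decoded text must not change |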
|
|
|
|
|
|
|
|
padded_input_ids = F.pad( |
|
|
input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id |
|
|
) |
|
|
padded_attention_mask = F.pad(attention_mask, (padding_length, 0), "constant", value=0) |
|
|
|
|
|
|
|
|
outputs_with_padding = model.generate( |
|
|
input_ids=padded_input_ids, |
|
|
attention_mask=padded_attention_mask, |
|
|
do_sample=False, |
|
|
penalty_alpha=penalty_alpha, |
|
|
top_k=top_k, |
|
|
max_new_tokens=64, |
|
|
) |
|
|
generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True) |
|
|
|
|
|
|
|
|
self.assertEqual(generated_text_no_padding, generated_text_with_padding) |
|
|
self.assertEqual( |
|
|
generated_text_with_padding, |
|
|
'The whispered legends of the haunted mansion spoke of the "souls of the dead" who were "falling ' |
|
|
'out of the sky" and "falling into the sea."\n\nThe ghostly apparitions were said to have been ' |
|
|
'created by the spirits of the dead, who were "falling out of the sky" and "falling into the sea', |
|
|
) |
|
|
|
|
|
@slow |
|
|
def test_padding_input_contrastive_search_t5(self): |
|
|
|
|
|
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small") |
|
|
model.to(torch_device) |
|
|
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small", clean_up_tokenization_spaces=True) |
|
|
|
|
|
|
|
|
prompt_text = "translate English to German: I need to finish this task before the end of the day." |
|
|
|
|
|
|
|
|
encoded_prompt = tokenizer(prompt_text, return_tensors="pt") |
|
|
input_ids = encoded_prompt.input_ids.to(torch_device) |
|
|
attention_mask = encoded_prompt.attention_mask.to(torch_device) |
|
|
|
|
|
|
|
|
decoder_prompt_text = "Ich muss diese Aufgabe" |
|
|
encoded_decoder_prompt = tokenizer(decoder_prompt_text, add_special_tokens=False, return_tensors="pt") |
|
|
decoder_input_ids = encoded_decoder_prompt.input_ids.to(torch_device) |
|
|
decoder_attention_mask = encoded_decoder_prompt.attention_mask.to(torch_device) |
|
|
|
|
|
|
|
|
penalty_alpha = 0.6 |
|
|
top_k = 4 |
|
|
|
|
|
|
|
|
outputs = model.generate( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
decoder_input_ids=decoder_input_ids, |
|
|
decoder_attention_mask=decoder_attention_mask, |
|
|
do_sample=False, |
|
|
penalty_alpha=penalty_alpha, |
|
|
top_k=top_k, |
|
|
max_new_tokens=64, |
|
|
) |
|
|
generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True) |
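# left-pad the decoder inputs and generate again; the decoded text must not change |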
|
|
|
|
|
|
|
|
padding_length = 10 |
|
|
|
|
|
|
|
|
padded_decoder_input_ids = F.pad( |
|
|
decoder_input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id |
|
|
) |
|
|
padded_decoder_attention_mask = F.pad(decoder_attention_mask, (padding_length, 0), "constant", value=0) |
|
|
|
|
|
|
|
|
|
|
|
padded_decoder_attention_mask[:, padding_length - 1] = 1 |
|
|
|
|
|
outputs_with_padding = model.generate( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
decoder_input_ids=padded_decoder_input_ids, |
|
|
decoder_attention_mask=padded_decoder_attention_mask, |
|
|
do_sample=False, |
|
|
penalty_alpha=penalty_alpha, |
|
|
top_k=top_k, |
|
|
max_new_tokens=64, |
|
|
) |
|
|
generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True) |
|
|
|
|
|
|
|
|
self.assertEqual(generated_text_no_padding, generated_text_with_padding) |
|
|
self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.") |
|
|
|
|
|
def test_prepare_inputs_for_generation_decoder_llm(self): |
|
|
"""Tests GenerationMixin.prepare_inputs_for_generation against expected usage with decoder-only llms.""" |
|
|
|
|
|
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") |
|
|
model = model.to(torch_device) |
|
|
|
|
|
|
|
|
self.assertTrue("GenerationMixin" in str(model.prepare_inputs_for_generation)) |
|
|
|
|
|
|
|
|
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device) |
|
|
model_inputs = model.prepare_inputs_for_generation(input_ids) |
|
|
self.assertTrue(torch.all(model_inputs["input_ids"] == input_ids)) |
|
|
|
|
|
|
|
|
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]).to(torch_device) |
|
|
model_inputs = model.prepare_inputs_for_generation(input_ids, attention_mask=attention_mask) |
|
|
self.assertTrue(torch.all(model_inputs["attention_mask"] == attention_mask)) |
|
|
self.assertTrue(model_inputs["position_ids"].shape == input_ids.shape) |
|
|
|
|
|
|
|
|
self.assertFalse("use_cache" in model_inputs) |
|
|
model_inputs = model.prepare_inputs_for_generation(input_ids, use_cache=True, foo="bar") |
|
|
self.assertTrue(model_inputs["use_cache"] is True) |
|
|
self.assertTrue(model_inputs["foo"] == "bar") |
|
|
|
|
|
|
|
|
|
|
|
init_input_ids = input_ids[:, :2] |
|
|
dynamic_cache = DynamicCache() |
|
|
dynamic_cache = model(init_input_ids, past_key_values=dynamic_cache).past_key_values |
|
|
with self.assertRaises(AttributeError): |
|
|
model_inputs = model.prepare_inputs_for_generation(input_ids, past_key_values=dynamic_cache) |
|
|
|
|
|
cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long).to(torch_device) |
|
|
cache_position = cache_position[dynamic_cache.get_seq_length() :] |
|
|
model_inputs = model.prepare_inputs_for_generation( |
|
|
input_ids, past_key_values=dynamic_cache, cache_position=cache_position, attention_mask=attention_mask |
|
|
) |
|
|
self.assertTrue("past_key_values" in model_inputs) |
|
|
self.assertTrue(torch.all(model_inputs["cache_position"] == cache_position)) |
|
|
self.assertTrue(model_inputs["input_ids"].shape[-1] == 1) |
|
|
self.assertTrue(model_inputs["position_ids"].shape[-1] == 1) |
|
|
self.assertTrue(model_inputs["attention_mask"].shape[-1] == 3) |
|
|
|
|
|
|
|
|
max_cache_len = 10 |
|
|
batch_size = 2 |
|
|
query_length = input_ids.shape[-1] - init_input_ids.shape[-1] |
|
|
static_cache = StaticCache( |
|
|
config=config, |
|
|
max_batch_size=batch_size, |
|
|
max_cache_len=max_cache_len, |
|
|
device=torch_device, |
|
|
dtype=torch.float32, |
|
|
) |
|
|
static_cache = model(init_input_ids, past_key_values=static_cache).past_key_values |
|
|
model_inputs = model.prepare_inputs_for_generation( |
|
|
input_ids, past_key_values=static_cache, cache_position=cache_position, attention_mask=attention_mask |
|
|
) |
|
|
self.assertTrue("past_key_values" in model_inputs) |
|
|
self.assertTrue(list(model_inputs["attention_mask"].shape) == [batch_size, 1, query_length, max_cache_len]) |
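# `inputs_embeds` are only forwarded on the first (prefill) call; afterwards the input ids are used |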
|
|
|
|
|
|
|
|
|
|
|
|
|
|
init_inputs_embeds = model.get_input_embeddings()(init_input_ids) |
|
|
init_cache_positions = torch.arange(init_input_ids.shape[-1], dtype=torch.long).to(torch_device) |
|
|
empty_cache = DynamicCache() |
|
|
|
|
|
|
|
|
model_inputs = model.prepare_inputs_for_generation( |
|
|
init_input_ids, |
|
|
past_key_values=empty_cache, |
|
|
inputs_embeds=init_inputs_embeds, |
|
|
cache_position=init_cache_positions, |
|
|
) |
|
|
self.assertTrue(model_inputs["input_ids"] is None) |
|
|
self.assertTrue(model_inputs["inputs_embeds"] is not None) |
|
|
|
|
|
|
|
|
model_inputs = model.prepare_inputs_for_generation( |
|
|
input_ids, past_key_values=dynamic_cache, inputs_embeds=init_inputs_embeds, cache_position=cache_position |
|
|
) |
|
|
self.assertTrue(model_inputs["input_ids"] is not None) |
|
|
self.assertTrue(model_inputs["inputs_embeds"] is None) |
|
|
|
|
|
def test_prepare_inputs_for_generation_encoder_decoder_llm(self): |
|
|
""" |
|
|
Same as `test_prepare_inputs_for_generation_decoder_llm` but for encoder-decoder models. Main difference: we |
|
|
should look for `decoder_input_ids` instead of `input_ids`. |
|
|
""" |
|
|
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5") |
|
|
model = model.to(torch_device) |
|
|
|
|
|
|
|
|
self.assertTrue("GenerationMixin" in str(model.prepare_inputs_for_generation)) |
|
|
|
|
|
|
|
|
decoder_input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device) |
|
|
model_inputs = model.prepare_inputs_for_generation(decoder_input_ids) |
|
|
self.assertTrue(torch.all(model_inputs["decoder_input_ids"] == decoder_input_ids)) |
|
|
|
|
|
|
|
|
|
|
|
decoder_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]).to(torch_device) |
|
|
model_inputs = model.prepare_inputs_for_generation( |
|
|
decoder_input_ids, decoder_attention_mask=decoder_attention_mask |
|
|
) |
|
|
self.assertTrue(torch.all(model_inputs["decoder_attention_mask"] == decoder_attention_mask)) |
|
|
self.assertTrue("position_ids" not in model_inputs) |
|
|
|
|
|
|
|
|
self.assertFalse("use_cache" in model_inputs) |
|
|
model_inputs = model.prepare_inputs_for_generation(decoder_input_ids, use_cache=True, encoder_outputs="foo") |
|
|
self.assertTrue(model_inputs["use_cache"] is True) |
|
|
self.assertTrue(model_inputs["encoder_outputs"] == "foo") |
|
|
|
|
|
|
|
|
def test_generate_compile_fullgraph_tiny(self): |
|
|
""" |
|
|
Tests that we can call compiled end-to-end generation with a tiny model (i.e. it doesn't crash) |
|
|
NOTE: this test is quite slow (~20s on a consumer desktop), but it is important that we keep it as part of the |
|
|
non-slow tests to prevent regressions! |
|
|
""" |
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
"hf-internal-testing/tiny-random-LlamaForCausalLM", torch_dtype=torch.bfloat16, device_map="auto" |
|
|
) |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") |
|
|
|
|
|
|
|
|
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead") |
|
|
|
|
|
|
|
|
generation_config = copy.deepcopy(model.generation_config) |
|
|
generation_config.pad_token_id = model.config.eos_token_id |
|
|
|
|
|
model_inputs = tokenizer(["Write a poem about the market crashing in summer"], return_tensors="pt") |
|
|
model_inputs = model_inputs.to(model.device) |
|
|
gen_out = compiled_generate(**model_inputs, generation_config=generation_config) |
|
|
self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) |
|
|
|
|
|
@require_read_token |
|
|
@slow |
|
|
def test_assisted_generation_early_exit(self): |
|
|
""" |
|
|
Tests that assisted generation with early exit works as expected. Under the hood, this has complex cache |
|
|
manipulation, which will cause the test to fail if something goes wrong there. |
|
|
""" |
|
|
expected_output = "Alice and Bob are playing a game of poker. Alice has a pair of 8s and Bob has a pair" |
|
|
|
|
|
prompt = "Alice and Bob" |
|
|
checkpoint = "facebook/layerskip-llama3.2-1B" |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(checkpoint) |
|
|
inputs = tokenizer(prompt, return_tensors="pt").to(torch_device) |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device) |
|
|
original_outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20) |
|
|
original_decoded = tokenizer.batch_decode(original_outputs, skip_special_tokens=True) |
|
|
self.assertEqual(original_decoded, [expected_output]) |
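# `assistant_early_exit=4` makes the model draft tokens through an early exit at layer 4; the greedy output must match the baseline |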
|
|
|
|
|
outputs_assisted = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=20) |
|
|
decoded_assisted = tokenizer.batch_decode(outputs_assisted, skip_special_tokens=True) |
|
|
self.assertEqual(decoded_assisted, [expected_output]) |
|
|
|
|
|
@slow |
|
|
def test_beam_search_advanced_stopping_criteria(self): |
|
|
""" |
|
|
Tests that beam search works with a stopping criterion other than max length or the EOS token. Prior to the beam |
|
|
search vectorization PR (#35802), beam search did not accept other stopping criteria. Test inspired by |
|
|
the original issue (#34843). |
|
|
""" |
|
|
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") |
|
|
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct").to(torch_device) |
|
|
|
|
|
prompt = ( |
|
|
"Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. " |
|
|
"How many clips did Natalia sell altogether in April and May?" |
|
|
) |
|
|
tokens = tokenizer(prompt, return_tensors="pt").to(torch_device) |
|
|
generation_config = GenerationConfig(num_beams=3, do_sample=False, length_penalty=1.0, max_new_tokens=100) |
|
|
|
|
|
|
|
|
out = model.generate(**tokens, generation_config=generation_config, tokenizer=tokenizer) |
|
|
output_text = tokenizer.decode(out[0], skip_special_tokens=True) |
|
|
last_non_special_token_decoded = tokenizer.decode(out[out != tokenizer.pad_token_id][-1]) |
|
|
self.assertTrue(":" in output_text) |
|
|
self.assertFalse(":" in output_text[-5:]) |
|
|
self.assertFalse(":" in last_non_special_token_decoded) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
generation_config.stop_strings = ":" |
|
|
out = model.generate(**tokens, generation_config=generation_config, tokenizer=tokenizer) |
|
|
output_text = tokenizer.decode(out[0], skip_special_tokens=True) |
|
|
last_non_special_token_decoded = tokenizer.decode(out[out != tokenizer.pad_token_id][-1]) |
|
|
self.assertTrue(":" in output_text) |
|
|
self.assertTrue(":" in output_text[-5:]) |
|
|
self.assertTrue(":" in last_non_special_token_decoded) |
|
|
|
|
|
def test_max_time(self): |
|
|
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") |
|
|
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2") |
|
|
model.to(torch_device) |
|
|
|
|
|
torch.manual_seed(0) |
|
|
tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True) |
|
|
input_ids = tokenized.input_ids.to(torch_device) |
|
|
|
|
|
MAX_TIME = 0.1 |
|
|
MAX_LENGTH = 64 |
|
|
|
|
|
|
|
|
start = datetime.datetime.now() |
|
|
model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=MAX_LENGTH) |
|
|
duration = datetime.datetime.now() - start |
|
|
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) |
|
|
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) |
|
|
|
|
|
|
|
|
start = datetime.datetime.now() |
|
|
model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=MAX_LENGTH) |
|
|
duration = datetime.datetime.now() - start |
|
|
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) |
|
|
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) |
|
|
|
|
|
|
|
|
start = datetime.datetime.now() |
|
|
model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=MAX_LENGTH) |
|
|
duration = datetime.datetime.now() - start |
|
|
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) |
|
|
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) |
|
|
|
|
|
|
|
|
start = datetime.datetime.now() |
|
|
model.generate(input_ids, do_sample=False, max_time=None, max_length=MAX_LENGTH) |
|
|
duration = datetime.datetime.now() - start |
|
|
self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) |
|
|
|
|
|
def test_validate_generation_inputs(self): |
|
|
"""Tests validation of inputs to `generate`""" |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") |
|
|
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5") |
|
|
|
|
|
encoder_input_str = "Hello world" |
|
|
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids |
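# a typo'd flag (`do_samples`) and unused model kwargs are rejected, while valid model kwargs are accepted |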
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(ValueError, "do_samples"): |
|
|
model.generate(input_ids, do_samples=True) |
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(ValueError, "foo"): |
|
|
fake_model_kwargs = {"foo": "bar"} |
|
|
model.generate(input_ids, **fake_model_kwargs) |
|
|
|
|
|
|
|
|
valid_model_kwargs = {"attention_mask": torch.tensor(np.zeros_like(input_ids))} |
|
|
model.generate(input_ids, **valid_model_kwargs) |
|
|
|
|
|
def test_custom_logits_processor(self): |
|
|
"""Tests that custom logits processors can be used in `generate`, and that redundant arguments are caught.""" |
|
|
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") |
|
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" |
|
|
bart_model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart", min_length=1) |
|
|
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids |
|
|
|
|
|
logits_processor = LogitsProcessorList() |
|
|
logits_processor.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0)) |
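# passing a logits processor that duplicates its equivalent `generate` argument (`min_length`) must raise a ValueError |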
|
|
|
|
|
|
|
|
with self.assertRaises(ValueError): |
|
|
bart_model.generate(input_ids, logits_processor=logits_processor, min_length=10) |
|
|
bart_model.generate(input_ids, logits_processor=logits_processor) |
|
|
|
|
|
def test_transition_scores_greedy_search(self): |
|
|
"""Test that `compute_transition_scores` is working as expected with gready search""" |
|
|
articles = ["Justin Timberlake", "Michael Phelps"] |
|
|
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2", padding_side="left") |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2") |
|
|
model.generation_config.eos_token_id = None |
|
|
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids |
|
|
model = model.to(torch_device) |
|
|
input_ids = input_ids.to(torch_device) |
|
|
|
|
|
outputs = model.generate( |
|
|
input_ids=input_ids, |
|
|
max_new_tokens=5, |
|
|
pad_token_id=tokenizer.eos_token_id, |
|
|
return_dict_in_generate=True, |
|
|
output_scores=True, |
|
|
) |
|
|
|
|
|
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores) |
|
|
transition_scores = transition_scores.cpu().numpy() |
|
|
|
|
|
expected_scores = np.array( |
|
|
[ |
|
|
[-57.8844, -60.45698, -70.16364, -65.50791, -66.35648], |
|
|
[-54.417572, -60.216614, -62.661243, -58.621933, -58.298683], |
|
|
] |
|
|
) |
|
|
self.assertTrue(np.allclose(transition_scores, expected_scores, atol=1e-3)) |
|
|
|
|
|
def test_transition_scores_greedy_search_normalized(self): |
|
|
""" |
|
|
Test that `compute_transition_scores` is working as expected with greedy search, with `normalize_logits=True` |
|
|
""" |
|
|
articles = ["Justin Timberlake", "Michael Phelps"] |
|
|
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2", padding_side="left") |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2") |
|
|
model.generation_config.eos_token_id = None |
|
|
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids |
|
|
model = model.to(torch_device) |
|
|
input_ids = input_ids.to(torch_device) |
|
|
|
|
|
outputs = model.generate( |
|
|
input_ids=input_ids, |
|
|
max_new_tokens=5, |
|
|
pad_token_id=tokenizer.eos_token_id, |
|
|
return_dict_in_generate=True, |
|
|
output_scores=True, |
|
|
) |
|
|
|
|
|
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True) |
|
|
transition_scores = transition_scores.cpu().numpy() |
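# hard-coded reference values: the log-probabilities of the greedily selected tokens |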
|
|
|
|
|
expected_scores = np.array( |
|
|
[ |
|
|
[-2.538938, -2.2694316, -2.1580915, -1.572299, -2.6719835], |
|
|
[-1.8826028, -2.2461371, -1.7556462, -2.9644494, -1.7996008], |
|
|
] |
|
|
) |
|
|
self.assertTrue(np.allclose(transition_scores, expected_scores, atol=1e-3)) |
|
|
|
|
|
def test_transition_scores_beam_search_encoder_decoder(self): |
|
|
""" |
|
|
Test that `compute_transition_scores` is working as expected with beam search and encoder-decoder models |
|
|
""" |
|
|
articles = [ |
|
|
"Justin Timberlake and Jessica Biel, welcome to parenthood.", |
|
|
"Michael Phelps is arguably the most decorated Olympian of all time.", |
|
|
] |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") |
|
|
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart") |
|
|
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids |
|
|
model = model.to(torch_device) |
|
|
input_ids = input_ids.to(torch_device) |
|
|
|
|
|
outputs = model.generate( |
|
|
input_ids=input_ids, |
|
|
max_length=10, |
|
|
num_beams=4, |
|
|
num_return_sequences=2, |
|
|
eos_token_id=None, |
|
|
return_dict_in_generate=True, |
|
|
output_scores=True, |
|
|
length_penalty=0.0, |
|
|
) |
|
|
|
|
|
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices) |
|
|
transition_scores = transition_scores.cpu().numpy() |
|
|
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy() |
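# with `length_penalty=0.0`, a sequence's score is exactly the sum of its per-token transition scores |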
|
|
|
|
|
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3)) |
|
|
|
|
|
def test_transition_scores_beam_search_encoder_decoder_with_eos(self): |
|
|
""" |
|
|
Test that `compute_transition_scores` is working as expected with beam search and encoder-decoder models, when |
|
|
an EOS token is defined |
|
|
""" |
|
|
articles = [ |
|
|
"Justin Timberlake and Jessica Biel, welcome to parenthood.", |
|
|
"Michael Phelps is arguably the most decorated Olympian of all time.", |
|
|
] |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") |
|
|
|
|
|
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart") |
|
|
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids |
|
|
model = model.to(torch_device) |
|
|
input_ids = input_ids.to(torch_device) |
|
|
|
|
|
outputs = model.generate( |
|
|
input_ids=input_ids, |
|
|
max_length=10, |
|
|
num_beams=4, |
|
|
num_return_sequences=2, |
|
|
return_dict_in_generate=True, |
|
|
output_scores=True, |
|
|
length_penalty=0.0, |
|
|
) |
|
|
|
|
|
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices) |
|
|
transition_scores = transition_scores.cpu().numpy() |
|
|
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy() |
|
|
|
|
|
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3)) |
|
|
|
|
|
def test_transition_scores_beam_search_decoder_only(self): |
|
|
""" |
|
|
Test that `compute_transition_scores` is working as expected with beam search and decoder-only models |
|
|
""" |
|
|
articles = [ |
|
|
"Justin Timberlake", |
|
|
"Michael Phelps", |
|
|
] |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids |
|
|
model = model.to(torch_device) |
|
|
input_ids = input_ids.to(torch_device) |
|
|
|
|
|
outputs = model.generate( |
|
|
input_ids=input_ids, |
|
|
max_length=10, |
|
|
num_beams=4, |
|
|
num_return_sequences=2, |
|
|
pad_token_id=tokenizer.eos_token_id, |
|
|
eos_token_id=None, |
|
|
return_dict_in_generate=True, |
|
|
output_scores=True, |
|
|
length_penalty=0.0, |
|
|
) |
|
|
|
|
|
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices) |
|
|
transition_scores = transition_scores.cpu().numpy() |
|
|
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy() |
|
|
|
|
|
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3)) |
|
|
|
|
|
@slow |
|
|
def test_transition_scores_early_stopping(self): |
|
|
""" |
|
|
Test that `compute_transition_scores` is working as expected with beam search and early stopping |
|
|
|
|
|
This is an aggressive test that makes sure that `beam_search`'s |
|
|
transition scores are computed correctly for varying `num_return_sequences`, `num_beams` and `batch_size > 1` |
|
|
The input below is 2 x input_ids for "question: How are you? \n context: I had a long day, " |
|
|
""" |
|
|
input_ids = torch.tensor(2 * [[822, 10, 571, 33, 25, 58, 2625, 10, 27, 141, 3, 9, 307, 239, 6, 1]]) |
|
|
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small") |
|
|
model = model.to(torch_device) |
|
|
input_ids = input_ids.to(torch_device) |
|
|
|
|
|
outputs = model.generate( |
|
|
input_ids, |
|
|
max_length=10, |
|
|
return_dict_in_generate=True, |
|
|
output_scores=True, |
|
|
forced_eos_token_id=model.config.eos_token_id, |
|
|
num_beams=4, |
|
|
do_sample=False, |
|
|
num_return_sequences=3, |
|
|
length_penalty=0.0, |
|
|
) |
|
|
|
|
|
transition_scores = model.compute_transition_scores( |
|
|
sequences=outputs.sequences, scores=outputs.scores, beam_indices=outputs.beam_indices |
|
|
) |
|
|
transition_scores = transition_scores.cpu().numpy() |
|
|
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy() |
|
|
|
|
|
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores)) |
|
|
|
|
|
def test_encoder_decoder_generate_attention_mask(self): |
|
|
""" |
|
|
Test that `generate` automagically creates the correct `attention_mask` for encoder-decoder models (which |
|
|
use a different keyword for it) |
|
|
""" |
|
|
articles = ["Timberlake", "Jessica Biel, welcome to parenthood among other things"] |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") |
|
|
|
|
|
|
|
|
model = AutoModelForSeq2SeqLM.from_pretrained( |
|
|
"hf-internal-testing/tiny-random-bart", |
|
|
) |
|
|
model.config.eos_token_id = None |
|
|
input_ids = tokenizer(articles[0], return_tensors="pt").input_ids |
|
|
input_ids_batched = tokenizer(articles, padding=True, return_tensors="pt").input_ids |
|
|
model = model.to(torch_device) |
|
|
input_ids = input_ids.to(torch_device) |
|
|
input_ids_batched = input_ids_batched.to(torch_device) |
|
|
|
|
|
generate_kwargs = { |
|
|
"return_dict_in_generate": True, |
|
|
"output_scores": True, |
|
|
"max_length": 50, |
|
|
"num_beams": 5, |
|
|
"num_return_sequences": 5, |
|
|
} |
|
|
|
|
|
output_sequences_batched = model.generate(input_ids=input_ids_batched, **generate_kwargs) |
|
|
output_sequences = model.generate(input_ids=input_ids, **generate_kwargs) |
|
|
|
|
|
batched_out = output_sequences_batched.sequences_scores |
|
|
out = output_sequences.sequences_scores |
|
|
batched_out = batched_out.cpu().numpy() |
|
|
out = out.cpu().numpy() |
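# the scores of the lone article must match its scores inside the padded batch, which only holds if |
# `generate` built the correct attention mask for the padding |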
|
|
|
|
|
diff = np.abs(np.sum(batched_out[:5]) - np.sum(out)) |
|
|
self.assertTrue(diff < 1e-4) |
|
|
|
|
|
def test_generate_input_ids_as_kwarg(self): |
|
|
"""Test that `input_ids` work equally as a positional and keyword argument in decoder-only models""" |
|
|
article = "I need input_ids to generate" |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=15) |
|
|
input_ids = tokenizer(article, return_tensors="pt").input_ids |
|
|
model = model.to(torch_device) |
|
|
input_ids = input_ids.to(torch_device) |
|
|
|
|
|
output_sequences_kwargs = model.generate(input_ids=input_ids) |
|
|
output_sequences = model.generate(input_ids) |
|
|
output_sequences_kwargs = output_sequences_kwargs.cpu().numpy() |
|
|
output_sequences = output_sequences.cpu().numpy() |
|
|
|
|
|
self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs)) |
|
|
self.assertEqual(output_sequences.shape, (1, 15)) |
|
|
|
|
|
def test_generate_input_ids_as_encoder_kwarg(self): |
|
|
"""Test that `input_ids` work equally as a positional and keyword argument in encoder-decoder models""" |
|
|
article = "Justin Timberlake and Jessica Biel, welcome to parenthood." |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") |
|
|
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart") |
|
|
model.config.eos_token_id = None |
|
|
input_ids = tokenizer(article, return_tensors="pt").input_ids |
|
|
model = model.to(torch_device) |
|
|
input_ids = input_ids.to(torch_device) |
|
|
|
|
|
output_sequences_kwargs = model.generate(input_ids=input_ids, max_length=5) |
|
|
output_sequences = model.generate(input_ids, max_length=5) |
|
|
output_sequences_kwargs = output_sequences_kwargs.cpu().numpy() |
|
|
output_sequences = output_sequences.cpu().numpy() |
|
|
|
|
|
self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs)) |
|
|
self.assertEqual(output_sequences.shape, (1, 5)) |
|
|
|
|
|
def test_generate_inputs_and_encoder_kwargs(self): |
|
|
""" |
|
|
Test that an exception is thrown if the main tensor (`input_ids` in LLMs) is passed as both a positional and |
|
|
keyword argument |
|
|
""" |
|
|
article = "I need input_ids to generate" |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10) |
|
|
input_ids = tokenizer(article, return_tensors="pt").input_ids |
|
|
with self.assertRaises(ValueError): |
|
|
model.generate(input_ids, input_ids=input_ids) |
|
|
|
|
|
def test_generate_too_many_encoder_kwargs(self): |
|
|
"""Test that passing redundant inputs results in an exception (`input_ids` and `inputs_embeds` in LLMs)""" |
|
|
article = "I need input_ids to generate" |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") |
|
|
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=10) |
|
|
input_ids = tokenizer(article, return_tensors="pt").input_ids |
|
|
with self.assertRaises(ValueError): |
|
|
model.generate(input_ids=input_ids, inputs_embeds=input_ids) |
|
|
|
|
|
def test_generate_input_features_as_encoder_kwarg(self): |
|
|
"""Test that non-`input_ids` main model inputs are correctly handled as positional arguments""" |
|
|
input_features = floats_tensor((3, 80, 60)) |
|
|
model = AutoModelForSpeechSeq2Seq.from_pretrained( |
|
|
"hf-internal-testing/tiny-random-WhisperForConditionalGeneration" |
|
|
) |
|
|
input_features = input_features.to(torch_device) |
|
|
model = model.to(torch_device) |
|
|
|
|
|
output_sequences_kwargs = model.generate(input_features=input_features, max_length=5) |
|
|
output_sequences = model.generate(input_features, max_length=5) |
|
|
output_sequences_kwargs = output_sequences_kwargs.cpu().numpy() |
|
|
output_sequences = output_sequences.cpu().numpy() |
|
|
|
|
|
self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs)) |
|
|
self.assertEqual(output_sequences.shape, (3, 5)) |
|
|
|
|
|
def test_generate_encoder_outputs_attention_mask(self): |
|
|
"""Test that `generate` can handle attention masks when the encoder outputs are passed""" |
|
|
input_features = floats_tensor((3, 80, 60)) |
|
|
attention_mask = torch.randint(0, 2, input_features.shape).to(torch_device) |
|
|
model = AutoModelForSpeechSeq2Seq.from_pretrained( |
|
|
"hf-internal-testing/tiny-random-WhisperForConditionalGeneration" |
|
|
) |
|
|
input_features = input_features.to(torch_device) |
|
|
attention_mask = attention_mask.to(torch_device) |
|
|
model = model.to(torch_device) |
|
|
|
|
|
encoder = model.get_encoder() |
|
|
encoder_outputs = encoder(input_features) |
|
|
|
|
|
output_sequences_no_mask = model.generate(encoder_outputs=encoder_outputs) |
|
|
output_sequences_with_mask = model.generate(encoder_outputs=encoder_outputs, attention_mask=attention_mask) |
|
|
output_sequences_no_mask = output_sequences_no_mask.cpu().numpy() |
|
|
output_sequences_with_mask = output_sequences_with_mask.cpu().numpy() |
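# if the attention mask were ignored, both generations would be identical |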
|
|
|
|
|
self.assertFalse(np.array_equal(output_sequences_no_mask, output_sequences_with_mask)) |
|
|
|
|
|
def test_eos_token_id_int_and_list_greedy_search(self): |
|
|
"""Test that `generate` can handle multiple EOS tokens""" |
|
|
generation_kwargs = { |
|
|
"do_sample": False, |
|
|
"num_beams": 1, |
|
|
} |
|
|
expectation = 13 |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
text = """Hello, my dog is cute and""" |
|
|
tokens = tokenizer(text, return_tensors="pt") |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
model = model.to(torch_device) |
|
|
tokens = tokens.to(torch_device) |
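# a single `eos_token_id` and a list containing it must stop generation at the same length |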
|
|
|
|
|
eos_token_id = 873 |
|
|
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) |
|
|
self.assertTrue(expectation == len(generated_tokens[0])) |
|
|
|
|
|
eos_token_id = [873, 198] |
|
|
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) |
|
|
self.assertTrue(expectation == len(generated_tokens[0])) |
|
|
|
|
|
def test_generate_vision2text_conditioning(self): |
|
|
"""Test that `decoder_input_ids` can be used to condition the generation in vision-to-text models""" |
|
|
pixel_values = floats_tensor((2, 3, 30, 30)) |
|
|
conditioning_input = torch.tensor([[10], [10]]) |
|
|
model = AutoModelForVision2Seq.from_pretrained( |
|
|
"hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2" |
|
|
) |
|
|
pixel_values = pixel_values.to(torch_device) |
|
|
model = model.to(torch_device) |
|
|
conditioning_input = conditioning_input.to(torch_device) |
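# the conditioning tokens may be passed either as `decoder_input_ids` or as `input_ids`; both must yield the |
# same sequences, with the conditioning token appearing right after the BOS token |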
|
|
|
|
|
|
|
|
|
|
|
output_sequences_decoder_input_ids = model.generate( |
|
|
pixel_values, max_length=5, decoder_input_ids=conditioning_input |
|
|
) |
|
|
output_sequences_input_ids = model.generate(pixel_values, max_length=5, input_ids=conditioning_input) |
|
|
output_sequences_decoder_input_ids = output_sequences_decoder_input_ids.cpu().numpy() |
|
|
output_sequences_input_ids = output_sequences_input_ids.cpu().numpy() |
|
|
conditioning_input = conditioning_input.cpu().numpy() |
|
|
|
|
|
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids)) |
|
|
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input)) |
|
|
|
|
|
@require_read_token |
|
|
@slow |
|
|
@require_torch_gpu |
|
|
def test_cache_device_map_with_vision_layer_device_map(self): |
|
|
""" |
|
|
Test that the cache device map is correctly set when the vision layer has a device map. Regression test for |
|
|
#36942 |
|
|
""" |
|
|
|
|
|
model_id = "google/gemma-3-4b-it" |
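# place part of the vision tower on GPU 0 and the rest (plus the language model) on CPU, reproducing the |
# setup from #36942 |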
|
|
|
|
|
|
|
|
device_map = { |
|
|
"vision_tower.vision_model.embeddings": 0, |
|
|
"vision_tower.vision_model.encoder.layers.0": 0, |
|
|
"vision_tower.vision_model.encoder.layers.1": 0, |
|
|
"vision_tower.vision_model.encoder.layers.2": 0, |
|
|
"vision_tower.vision_model.encoder.layers.3": 0, |
|
|
"vision_tower.vision_model.encoder.layers.4": 0, |
|
|
"vision_tower.vision_model.encoder.layers.5": 0, |
|
|
"vision_tower.vision_model.encoder.layers.6": 0, |
|
|
"vision_tower.vision_model.encoder.layers.7": 0, |
|
|
"vision_tower.vision_model.encoder.layers.8": 0, |
|
|
"vision_tower.vision_model.encoder.layers.9": 0, |
|
|
"vision_tower.vision_model.encoder.layers.10": 0, |
|
|
"vision_tower.vision_model.encoder.layers.11": 0, |
|
|
"vision_tower.vision_model.encoder.layers.12": 0, |
|
|
"vision_tower.vision_model.encoder.layers.13": 0, |
|
|
"vision_tower.vision_model.encoder.layers.14": "cpu", |
|
|
"vision_tower.vision_model.encoder.layers.15": "cpu", |
|
|
"vision_tower.vision_model.encoder.layers.16": "cpu", |
|
|
"vision_tower.vision_model.encoder.layers.17": "cpu", |
|
|
"vision_tower.vision_model.encoder.layers.18": "cpu", |
|
|
"vision_tower.vision_model.encoder.layers.19": "cpu", |
|
|
"vision_tower.vision_model.encoder.layers.20": "cpu", |
|
|
"vision_tower.vision_model.encoder.layers.21": "cpu", |
|
|
"vision_tower.vision_model.encoder.layers.22": "cpu", |
|
|
"vision_tower.vision_model.encoder.layers.23": "cpu", |
|
|
"vision_tower.vision_model.encoder.layers.24": "cpu", |
|
|
"vision_tower.vision_model.encoder.layers.25": "cpu", |
|
|
"vision_tower.vision_model.encoder.layers.26": "cpu", |
|
|
"vision_tower.vision_model.post_layernorm": "cpu", |
|
|
"multi_modal_projector": "cpu", |
|
|
"language_model": "cpu", |
|
|
} |
|
|
|
|
|
model = AutoModelForImageTextToText.from_pretrained( |
|
|
model_id, device_map=device_map, torch_dtype=torch.bfloat16 |
|
|
) |
|
|
tokenizer = AutoTokenizer.from_pretrained(model_id) |
|
|
inputs = tokenizer(["This is a text input"], return_tensors="pt").to(model.device) |
|
|
|
|
|
|
|
|
_ = model.generate(**inputs, max_new_tokens=2, do_sample=False) |
|
|
|
|
|
@require_torch_gpu |
|
|
def test_cpu_offload_doesnt_compile(self): |
|
|
"""Test that CPU offload doesn't trigger compilation""" |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") |
|
|
tokenized_inputs = tokenizer(["Hello world"], return_tensors="pt") |
|
|
generate_kwargs = {"max_new_tokens": 3, "cache_implementation": "static"} |
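# with the whole model on the accelerator, `generate` with a static cache compiles the decoding forward pass |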
|
|
|
|
|
|
|
|
model_gpu = AutoModelForCausalLM.from_pretrained( |
|
|
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto" |
|
|
) |
|
|
input_ids = tokenized_inputs.input_ids.to(model_gpu.device) |
|
|
_ = model_gpu.generate(input_ids, **generate_kwargs) |
|
|
self.assertTrue(hasattr(model_gpu, "_compiled_call")) |
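# with some layers offloaded to the CPU, compilation must be skipped |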
|
|
|
|
|
|
|
|
|
|
|
device_map = { |
|
|
"model.embed_tokens": 0, |
|
|
"model.layers.0": 0, |
|
|
"model.layers.1": "cpu", |
|
|
"model.norm": "cpu", |
|
|
"lm_head": 0, |
|
|
} |
|
|
model_cpu = AutoModelForCausalLM.from_pretrained( |
|
|
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map |
|
|
) |
|
|
input_ids = tokenized_inputs.input_ids.to(model_cpu.device) |
|
|
_ = model_cpu.generate(input_ids, **generate_kwargs) |
|
|
self.assertFalse(hasattr(model_cpu, "_compiled_call")) |
|
|
|
|
|
|
|
|
@require_torch |
|
|
class TokenHealingTestCase(unittest.TestCase): |
|
|
@parameterized.expand( |
|
|
[ |
|
|
("url", 'The link is <a href="http:', 'The link is <a href="http://'), |
|
|
|
|
|
("aggressive_healing", 'The link is <a href="http', 'The link is <a href="http'), |
|
|
("trailing_whitespace", "I read a book about ", "I read a book about"), |
|
|
("nothing_to_heal", "I read a book about", "I read a book about"), |
|
|
("single_token", "I", "I"), |
|
|
("empty_prompt", "", ""), |
|
|
] |
|
|
) |
|
|
def test_prompts(self, name, input, expected): |
|
|
model_name_or_path = "distilbert/distilgpt2" |
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True) |
|
|
completion_model = AutoModelForCausalLM.from_pretrained( |
|
|
model_name_or_path, |
|
|
device_map="auto", |
|
|
trust_remote_code=False, |
|
|
revision="main", |
|
|
use_cache=True, |
|
|
) |
|
|
|
|
|
""" |
|
|
tokenizer.pad_token value can be empty but it is required in the latter codes |
|
|
so assigned it here with eos_token |
|
|
""" |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
input_ids = tokenizer(input, return_tensors="pt").input_ids.to(completion_model.device) |
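# heal the tail of the prompt and decode the healed ids back to text for comparison |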
|
|
|
|
|
healed_ids = completion_model.heal_tokens(input_ids, tokenizer=tokenizer) |
|
|
predicted = tokenizer.decode(healed_ids[0], skip_special_tokens=True) |
|
|
|
|
|
self.assertEqual(predicted, expected) |
|
|
|
|
|
def test_generate_from_inputs_embeds_with_bos_token_id_is_none(self): |
|
|
article = "Today a dragon flew over Paris." |
|
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) |
|
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") |
|
|
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) |
|
|
inputs_embeds = model.get_input_embeddings()(input_ids) |
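# `inputs_embeds` alone must be enough to start generation, even when `bos_token_id` is None |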
|
|
|
|
|
model.generate(inputs_embeds=inputs_embeds, max_length=20, bos_token_id=None) |
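# with neither `input_ids` nor `inputs_embeds`, generation cannot start |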
|
|
|
|
|
|
|
|
with self.assertRaises(ValueError): |
|
|
model.generate(max_length=20, bos_token_id=None) |
|
|
|
|
|
|
|
|
class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase): |
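# `_get_tokens_diag` compares the original prompt with the assistant's prompt-plus-new-tokens and returns a |
# tuple of (discrepancy length, genuinely new tokens, discrepant tokens) |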
|
|
def test_no_intersection(self): |
|
|
prompt = np.array([[1, 2, 3]]) |
|
|
prompt_plus_new_tokens = np.array([[4, 5, 6]]) |
|
|
result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens) |
|
|
self.assertEqual(result, (None, None, None)) |
|
|
|
|
|
def test_complete_overlap(self): |
|
|
prompt = np.array([[1, 2, 3]]) |
|
|
prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]]) |
|
|
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag( |
|
|
prompt, prompt_plus_new_tokens |
|
|
) |
|
|
self.assertEqual(discrep_length, 0) |
|
|
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]])) |
|
|
np.testing.assert_array_equal(discrep_only, np.array([[]])) |
|
|
|
|
|
def test_partial_overlap(self): |
|
|
prompt = np.array([[1, 2, 3]]) |
|
|
prompt_plus_new_tokens = np.array([[2, 3, 4, 5]]) |
|
|
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag( |
|
|
prompt, prompt_plus_new_tokens |
|
|
) |
|
|
self.assertEqual(discrep_length, 0) |
|
|
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]])) |
|
|
np.testing.assert_array_equal(discrep_only, np.array([[]])) |
|
|
|
|
|
def test_no_new_tokens(self): |
|
|
prompt = np.array([[1, 2, 3]]) |
|
|
prompt_plus_new_tokens = np.array([[1, 2, 3]]) |
|
|
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag( |
|
|
prompt, prompt_plus_new_tokens |
|
|
) |
|
|
self.assertEqual(discrep_length, 0) |
|
|
np.testing.assert_array_equal(new_tokens_only, np.array([[]])) |
|
|
np.testing.assert_array_equal(discrep_only, np.array([[]])) |
|
|
|
|
|
|
|
|
class TestAssistedCandidateGeneratorUpdateStrategy(unittest.TestCase): |
|
|
def setUp(self): |
|
|
checkpoint = "EleutherAI/pythia-160m-deduped" |
|
|
self.assistant_model = AutoModelForCausalLM.from_pretrained(checkpoint) |
|
|
self.assistant_model.generation_config.assistant_confidence_threshold = 0.4 |
|
|
self.model_kwargs = {} |
|
|
self.input_ids = torch.randint(1, 10, (1, 9)) |
|
|
self.candidate_generator = AssistedCandidateGenerator( |
|
|
input_ids=self.input_ids, |
|
|
assistant_model=self.assistant_model, |
|
|
generation_config=self.assistant_model.generation_config, |
|
|
model_kwargs=self.model_kwargs, |
|
|
) |
|
|
self.candidate_generator.probs = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1] |
|
|
self.original_probs = self.candidate_generator.probs |
|
|
self.original_threshold = self.assistant_model.generation_config.assistant_confidence_threshold |
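# when sklearn is not available, `update_candidate_strategy` must leave `matches`, `probs` and the |
# confidence threshold untouched |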
|
|
|
|
|
def assert_no_sklearn(self): |
|
|
with patch("transformers.utils.import_utils._sklearn_available", False): |
|
|
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) |
|
|
self.assertEqual(self.candidate_generator.matches, self.original_matches) |
|
|
self.assertEqual(self.candidate_generator.probs, self.original_probs) |
|
|
self.assertEqual( |
|
|
self.assistant_model.generation_config.assistant_confidence_threshold, self.original_threshold |
|
|
) |
|
|
|
|
|
@parameterized.expand([(is_sklearn_available(),), (False,)]) |
|
|
def test_update_candidate_strategy_no_matches_short(self, sklearn_available): |
|
|
|
|
self.original_matches = [] |
|
|
self.candidate_generator.matches = self.original_matches |
|
|
self.num_matches = 0 |
|
|
|
|
|
if sklearn_available: |
|
|
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) |
|
|
self.assertEqual(self.candidate_generator.matches, [0]) |
|
|
self.assertEqual(self.candidate_generator.probs, [0.9]) |
|
|
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4) |
|
|
else: |
|
|
self.assert_no_sklearn() |
|
|
|
|
|
@parameterized.expand([(is_sklearn_available(),), (False,)]) |
|
|
def test_update_candidate_strategy_with_mix_matches_3(self, sklearn_available): |
|
|
self.original_matches = [1, 0, 1, 0, 1] |
|
|
self.candidate_generator.matches = self.original_matches |
|
|
self.num_matches = 3 |
|
|
if sklearn_available: |
|
|
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) |
|
|
self.assertEqual(self.candidate_generator.matches, [1, 0, 1, 0, 1, 1, 1, 1, 0]) |
|
|
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]) |
|
|
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.2) |
|
|
else: |
|
|
self.assert_no_sklearn() |
|
|
|
|
|
@parameterized.expand([(is_sklearn_available(),), (False,)]) |
|
|
def test_update_candidate_strategy_with_matches_4(self, sklearn_available): |
|
|
self.original_matches = [1, 1, 1, 1, 1] |
|
|
self.candidate_generator.matches = self.original_matches |
|
|
self.num_matches = 4 |
|
|
if sklearn_available: |
|
|
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) |
|
|
self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 1, 1]) |
|
|
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]) |
|
|
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4) |
|
|
else: |
|
|
self.assert_no_sklearn() |
|
|
|
|
|
@parameterized.expand([(is_sklearn_available(),), (False,)]) |
|
|
def test_update_candidate_strategy_with_matches_3(self, sklearn_available): |
|
|
self.original_matches = [1, 1, 1, 1, 1] |
|
|
self.candidate_generator.matches = self.original_matches |
|
|
self.num_matches = 3 |
|
|
if sklearn_available: |
|
|
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) |
|
|
self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 1, 0]) |
|
|
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]) |
|
|
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.2) |
|
|
else: |
|
|
self.assert_no_sklearn() |
|
|
|
|
|
@parameterized.expand([(is_sklearn_available(),), (False,)]) |
|
|
def test_update_candidate_strategy_with_matches_2(self, sklearn_available): |
|
|
self.original_matches = [1, 1, 1, 1, 1] |
|
|
self.candidate_generator.matches = self.original_matches |
|
|
self.num_matches = 2 |
|
|
if sklearn_available: |
|
|
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) |
|
|
self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 0]) |
|
|
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2]) |
|
|
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.3) |
|
|
else: |
|
|
self.assert_no_sklearn() |
|
|
|
|
|
@parameterized.expand([(is_sklearn_available(),), (False,)]) |
|
|
def test_update_candidate_strategy_with_matches_1(self, sklearn_available): |
|
|
self.original_matches = [1, 1, 1, 1, 1] |
|
|
self.candidate_generator.matches = self.original_matches |
|
|
self.num_matches = 1 |
|
|
if sklearn_available: |
|
|
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) |
|
|
self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 0]) |
|
|
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3]) |
|
|
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4) |
|
|
else: |
|
|
self.assert_no_sklearn() |
|
|
|