from typing import Dict, Optional

import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla

from transformers.generation_flax_utils import FlaxGenerationMixin
from transformers.utils import logging


logger = logging.get_logger(__name__)
class VaeFlaxGenerationMixin(FlaxGenerationMixin):
    def generate(
        self,
        latent_codes: jax_xla.DeviceArray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        do_sample: Optional[bool] = None,
        prng_key: Optional[jax_xla.DeviceArray] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        num_beams: Optional[int] = None,
        no_repeat_ngram_size: Optional[int] = None,
        min_length: Optional[int] = None,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        trace: bool = True,
        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
        **model_kwargs,
    ):
r""" | |
Generates sequences for models with a language modeling head. The method currently supports greedy decoding, | |
and, multinomial sampling. | |
Apart from :obj:`latent_codes`, all the arguments below will default to the value of the attribute of the same | |
name inside the :class:`~transformers.PretrainedConfig` of the model. The default values indicated are the | |
default values of those config. | |
Most of these parameters are explained in more detail in `this blog post | |
<https://huggingface.co/blog/how-to-generate>`__. | |
Parameters: | |
latent_codes (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, n_latent_tokens, latent_token_dim)`, `optional`): | |
The sequence used as a prompt for the generation. | |
max_length (:obj:`int`, `optional`, defaults to 20): | |
The maximum length of the sequence to be generated. | |
do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`): | |
Whether or not to use sampling ; use greedy decoding otherwise. | |
temperature (:obj:`float`, `optional`, defaults to 1.0): | |
The value used to module the next token probabilities. | |
top_k (:obj:`int`, `optional`, defaults to 50): | |
The number of highest probability vocabulary tokens to keep for top-k-filtering. | |
top_p (:obj:`float`, `optional`, defaults to 1.0): | |
If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or | |
higher are kept for generation. | |
pad_token_id (:obj:`int`, `optional`): | |
The id of the `padding` token. | |
bos_token_id (:obj:`int`, `optional`): | |
The id of the `beginning-of-sequence` token. | |
eos_token_id (:obj:`int`, `optional`): | |
The id of the `end-of-sequence` token. | |
num_beams (:obj:`int`, `optional`, defaults to 1): | |
Number of beams for beam search. 1 means no beam search. | |
decoder_start_token_id (:obj:`int`, `optional`): | |
If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token. | |
trace (:obj:`bool`, `optional`, defaults to :obj:`True`): | |
Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to | |
a considerably slower runtime. | |
params (:obj:`Dict[str, jax_xla.DeviceArray]`, `optional`): | |
Optionally the model parameters can be passed. Can be useful for parallelized generation. | |
model_kwargs: | |
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. | |
Return: | |
:class:`~transformers.file_utils.ModelOutput`. | |
Examples:: | |
>>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM | |
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") | |
>>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2") | |
>>> input_context = "The dog" | |
>>> # encode input context | |
>>> input_ids = tokenizer(input_context, return_tensors="jax").input_ids | |
>>> # generate candidates using sampling | |
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True) | |
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) | |
""" | |
        # set init values
        max_length = max_length if max_length is not None else self.config.max_length
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        if decoder_start_token_id is None and self.config.is_encoder_decoder:
            raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")

        # Forward the latent codes to every decoder call through `model_kwargs`.
        model_kwargs["latent_codes"] = latent_codes

        if self.config.is_encoder_decoder:
            # NOTE: Unlike `FlaxGenerationMixin.generate`, encoder outputs are not prepared here;
            # the decoder is conditioned on `latent_codes` instead.
            # model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
            # prepare decoder_input_ids for generation: every sequence starts from `decoder_start_token_id`
            # (note that `input_ids` is only defined on this encoder-decoder path).
            input_ids = jnp.ones((latent_codes.shape[0], 1), dtype="i4") * decoder_start_token_id
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        num_beams = num_beams if num_beams is not None else self.config.num_beams

        if not do_sample and num_beams == 1:
            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
            )
            return self._greedy_search(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        elif do_sample and num_beams == 1:
            logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
            )
            return self._sample(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                prng_key,
                logits_warper=logits_warper,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        elif not do_sample and num_beams > 1:
            # broadcast input_ids & encoder_outputs
            input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)

            if "encoder_outputs" in model_kwargs:
                model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
                )

            if "attention_mask" in model_kwargs:
                model_kwargs["attention_mask"] = self._expand_to_num_beams(
                    model_kwargs["attention_mask"], num_beams=num_beams
                )

            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
            )
            return self._beam_search(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                length_penalty=length_penalty,
                early_stopping=early_stopping,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        else:
            raise NotImplementedError("Beam sampling is currently not implemented.")
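

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of this module): `VaeFlaxGenerationMixin`
# is meant to be combined with a Flax encoder-decoder model whose forward pass
# accepts a `latent_codes` keyword argument, since the codes are forwarded via
# `model_kwargs` above. The names `FlaxVaeSeq2SeqModel`, `SomeFlaxSeq2SeqModel`,
# `vae_model`, and `tokenizer` below are hypothetical placeholders, not classes
# defined in this repository.
#
#   class FlaxVaeSeq2SeqModel(VaeFlaxGenerationMixin, SomeFlaxSeq2SeqModel):
#       ...
#
#   vae_model = FlaxVaeSeq2SeqModel.from_pretrained(...)
#   latent_codes = ...  # (batch_size, n_latent_tokens, latent_token_dim), e.g. from the VAE encoder
#   outputs = vae_model.generate(latent_codes, max_length=64, num_beams=4)
#   texts = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
# ---------------------------------------------------------------------------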