# Hugging Face Hub file-view header (scraped): uploaded by bhavitvyamalik,
# commit 668c729 ("weights and model"), file size 39.8 kB.
from typing import Dict, Optional
import flax
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
import numpy as np
from jax import lax
from transformers.file_utils import ModelOutput
from transformers.generation_flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxLogitsProcessorList,
FlaxMinLengthLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
)
from transformers.utils import logging
# Module-level logger obtained from transformers' logging utilities.
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class FlaxGreedySearchOutput(ModelOutput):
    """
    Output container for greedy-search generation (Flax).

    Args:
        sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
            Token ids of the generated sequences.
    """

    sequences: jax_xla.DeviceArray = None
@flax.struct.dataclass
class FlaxSampleOutput(ModelOutput):
    """
    Output container for multinomial-sampling generation (Flax).

    Args:
        sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
            Token ids of the sampled sequences.
    """

    sequences: jax_xla.DeviceArray = None
@flax.struct.dataclass
class FlaxBeamSearchOutput(ModelOutput):
    """
    Flax base class for outputs of generation models using beam search.

    Args:
        sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
            The generated sequences (one selected beam per batch item).
        scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size,)`):
            The scores (cumulative log probabilities; length-penalized for finished beams) of the
            generated sequences.
    """

    sequences: jax_xla.DeviceArray = None
    scores: jax_xla.DeviceArray = None
@flax.struct.dataclass
class GreedyState:
    """Loop carry threaded through the greedy-search loop (``lax.while_loop`` or the debug loop)."""

    cur_len: jax_xla.DeviceArray  # next write position in `sequences`
    sequences: jax_xla.DeviceArray  # (batch_size, max_length) token ids, pre-filled with pad
    running_token: jax_xla.DeviceArray  # token(s) fed to the model on the next step
    is_sent_finished: jax_xla.DeviceArray  # (batch_size,) bool flags, set once EOS is produced
    model_kwargs: Dict[str, jax_xla.DeviceArray]  # from prepare/update_inputs_for_generation
@flax.struct.dataclass
class SampleState:
    """Loop carry threaded through the sampling loop; like ``GreedyState`` plus a PRNG key."""

    cur_len: jax_xla.DeviceArray  # next write position in `sequences`
    sequences: jax_xla.DeviceArray  # (batch_size, max_length) token ids, pre-filled with pad
    running_token: jax_xla.DeviceArray  # token(s) fed to the model on the next step
    is_sent_finished: jax_xla.DeviceArray  # (batch_size,) bool flags, set once EOS is produced
    prng_key: jax_xla.DeviceArray  # split on every step; one half is used for sampling
    model_kwargs: Dict[str, jax_xla.DeviceArray]  # from prepare/update_inputs_for_generation
@flax.struct.dataclass
class BeamSearchState:
    """Loop carry threaded through the beam-search loop."""

    cur_len: jax_xla.DeviceArray  # next write position along the length axis
    running_sequences: jax_xla.DeviceArray  # (batch_size, num_beams, max_length) live beams
    running_scores: jax_xla.DeviceArray  # (batch_size, num_beams) log-prob scores of live beams
    sequences: jax_xla.DeviceArray  # (batch_size, num_beams, max_length) finished beams
    scores: jax_xla.DeviceArray  # (batch_size, num_beams) scores of finished beams
    is_sent_finished: jax_xla.DeviceArray  # (batch_size, num_beams) bool flags for `sequences`
    model_kwargs: Dict[str, jax_xla.DeviceArray]  # decoder kwargs, incl. the beam-gathered cache
class FlaxCLIPVisionMBartGenerationMixin:
    """
    A class containing all of the functions supporting generation (greedy search, multinomial sampling
    and beam search), to be used as a mixin in :class:`~transformers.FlaxPreTrainedModel`.

    Generation defaults are read from ``self.config.mbart_config``; ``self.config.is_encoder_decoder``
    switches on encoder-decoder handling. The host class must provide ``encode``, ``decode``,
    ``prepare_inputs_for_generation`` and ``update_inputs_for_generation``.
    """

    @staticmethod
    def _run_loop_in_debug(cond_fn, body_fn, init_state):
        """
        Run generation in untraced mode (plain Python ``while`` loop). This should only be used for
        debugging purposes.
        """
        state = init_state
        while cond_fn(state):
            state = body_fn(state)
        return state

    def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs):
        """
        Run the encoder once and store its output in ``model_kwargs["encoder_outputs"]``.

        Keyword arguments whose names start with ``decoder_`` or ``cross_attn`` are not forwarded to
        ``self.encode``. ``model_kwargs`` is updated in place and also returned.
        """
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not argument.startswith(("decoder_", "cross_attn"))
        }
        model_kwargs["encoder_outputs"] = self.encode(input_ids, return_dict=True, **encoder_kwargs)
        return model_kwargs

    @staticmethod
    def _expand_to_num_beams(tensor, num_beams):
        """Broadcast ``tensor`` from ``(batch_size, ...)`` to ``(batch_size, num_beams, ...)``."""
        return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])

    def generate(
        self,
        input_ids: jax_xla.DeviceArray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        do_sample: Optional[bool] = None,
        prng_key: Optional[jax_xla.DeviceArray] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        num_beams: Optional[int] = None,
        no_repeat_ngram_size: Optional[int] = None,
        min_length: Optional[int] = None,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        trace: bool = True,
        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
        **model_kwargs,
    ):
        r"""
        Generates sequences with greedy decoding, multinomial sampling or beam search.

        Unless stated otherwise, every generation argument left as :obj:`None` falls back to the attribute
        of the same name on ``self.config.mbart_config``.

        Most of these parameters are explained in more detail in `this blog post
        <https://huggingface.co/blog/how-to-generate>`__.

        Parameters:
            input_ids (:obj:`jax_xla.DeviceArray`):
                For encoder-decoder configs this is passed to ``self.encode`` (presumably the image input of
                the CLIP vision encoder -- confirm against the host model) and decoding starts from
                ``decoder_start_token_id``. Otherwise it is the decoder prompt of shape
                :obj:`(batch_size, sequence_length)`.
            max_length (:obj:`int`, `optional`):
                Length of the returned sequences, including the start token(s).
            pad_token_id (:obj:`int`, `optional`):
                The id of the `padding` token; also written after a sequence has finished.
            bos_token_id (:obj:`int`, `optional`):
                The id of the `beginning-of-sequence` token (resolved but not used by the decoding loops).
            eos_token_id (:obj:`int`, `optional`):
                The id of the `end-of-sequence` token.
            decoder_start_token_id (:obj:`int`, `optional`):
                Token that decoding starts from for encoder-decoder models. Required in that case.
            do_sample (:obj:`bool`, `optional`):
                Whether or not to use sampling; use greedy decoding / beam search otherwise.
            prng_key (:obj:`jax_xla.DeviceArray`, `optional`):
                PRNG key for sampling. Defaults to ``jax.random.PRNGKey(0)`` (not read from the config).
            top_k (:obj:`int`, `optional`):
                The number of highest probability vocabulary tokens to keep for top-k filtering.
            top_p (:obj:`float`, `optional`):
                If set to float < 1, only the most probable tokens with probabilities that add up to
                :obj:`top_p` or higher are kept for generation.
            temperature (:obj:`float`, `optional`):
                The value used to modulate the next token probabilities.
            num_beams (:obj:`int`, `optional`):
                Number of beams for beam search. 1 means no beam search.
            no_repeat_ngram_size (:obj:`int`, `optional`):
                Resolved from the config, but no n-gram blocking processor is applied by this mixin.
            min_length (:obj:`int`, `optional`):
                EOS is suppressed until this length is reached.
            forced_bos_token_id (:obj:`int`, `optional`):
                Token forced as the first generated token.
            forced_eos_token_id (:obj:`int`, `optional`):
                Token forced when :obj:`max_length` is reached.
            length_penalty (:obj:`float`, `optional`):
                Exponent of the length normalization used by beam search.
            early_stopping (:obj:`bool`, `optional`):
                Beam-search stopping criterion flag.
            trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether to trace generation. Setting ``trace=False`` should only be used for debugging and
                will lead to a considerably slower runtime.
            params (:obj:`Dict[str, jax_xla.DeviceArray]`, `optional`):
                Optionally the model parameters can be passed. Can be useful for parallelized generation.
            model_kwargs:
                Additional model specific kwargs forwarded to the encoder and/or decoder.

        Return:
            :class:`FlaxGreedySearchOutput`, :class:`FlaxSampleOutput` or :class:`FlaxBeamSearchOutput`,
            depending on ``do_sample`` / ``num_beams``.

        Raises:
            ValueError: if ``decoder_start_token_id`` cannot be resolved for an encoder-decoder config.
            NotImplementedError: for beam sampling (``do_sample`` with ``num_beams > 1``) or any other
                unsupported ``do_sample`` / ``num_beams`` combination.

        Examples::

            >>> # `model` mixes in FlaxCLIPVisionMBartGenerationMixin
            >>> outputs = model.generate(input_ids, max_length=64, num_beams=4)
            >>> print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))
        """
        mbart_config = self.config.mbart_config

        # set init values
        max_length = max_length if max_length is not None else mbart_config.max_length
        bos_token_id = bos_token_id if bos_token_id is not None else mbart_config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else mbart_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else mbart_config.eos_token_id
        # `is not None` rather than truthiness: a start token id of 0 is valid and must not be
        # silently replaced by the config value.
        decoder_start_token_id = (
            decoder_start_token_id
            if decoder_start_token_id is not None
            else mbart_config.decoder_start_token_id
        )
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        if decoder_start_token_id is None and self.config.is_encoder_decoder:
            raise ValueError(
                "`decoder_start_token_id` has to be defined for encoder-decoder generation."
            )

        if self.config.is_encoder_decoder:
            # add encoder_outputs to model_kwargs
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                input_ids, model_kwargs
            )
            # prepare decoder_input_ids for generation
            input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

        do_sample = do_sample if do_sample is not None else mbart_config.do_sample
        num_beams = num_beams if num_beams is not None else mbart_config.num_beams

        is_greedy = not do_sample and num_beams == 1
        is_sample = do_sample and num_beams == 1
        is_beam_search = not do_sample and num_beams > 1
        if not (is_greedy or is_sample or is_beam_search):
            raise NotImplementedError("Beam sampling is currently not implemented.")

        # Every supported strategy uses the same processors (min length, forced BOS/EOS, ...).
        logits_processor = self._get_logits_processor(
            no_repeat_ngram_size,
            min_length,
            max_length,
            eos_token_id,
            forced_bos_token_id,
            forced_eos_token_id,
        )

        if is_greedy:
            return self._greedy_search(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )

        if is_sample:
            logits_warper = self._get_logits_warper(
                top_k=top_k, top_p=top_p, temperature=temperature
            )
            return self._sample(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                prng_key,
                logits_warper=logits_warper,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )

        # beam search: broadcast input_ids & encoder_outputs to (batch_size, num_beams, ...)
        input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
        if "encoder_outputs" in model_kwargs:
            model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
                model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
            )
        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = self._expand_to_num_beams(
                model_kwargs["attention_mask"], num_beams=num_beams
            )
        return self._beam_search(
            input_ids,
            max_length,
            pad_token_id,
            eos_token_id,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            logits_processor=logits_processor,
            trace=trace,
            params=params,
            model_kwargs=model_kwargs,
        )

    def _get_logits_warper(
        self, top_k: int = None, top_p: float = None, temperature: float = None
    ) -> FlaxLogitsProcessorList:
        """
        Return a :obj:`~transformers.FlaxLogitsProcessorList` with the
        :obj:`~transformers.FlaxLogitsWarper` instances used for multinomial sampling
        (temperature, top-k, top-p). Arguments left as :obj:`None` fall back to
        ``self.config.mbart_config``; warpers whose setting is a no-op are skipped.
        """
        # init warp parameters
        top_k = top_k if top_k is not None else self.config.mbart_config.top_k
        top_p = top_p if top_p is not None else self.config.mbart_config.top_p
        temperature = (
            temperature if temperature is not None else self.config.mbart_config.temperature
        )

        # instantiate warpers list
        warpers = FlaxLogitsProcessorList()

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        if temperature is not None and temperature != 1.0:
            warpers.append(FlaxTemperatureLogitsWarper(temperature))
        if top_k is not None and top_k != 0:
            warpers.append(FlaxTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
        if top_p is not None and top_p < 1.0:
            warpers.append(FlaxTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))

        return warpers

    def _get_logits_processor(
        self,
        no_repeat_ngram_size: int,
        min_length: int,
        max_length: int,
        eos_token_id: int,
        forced_bos_token_id: int,
        forced_eos_token_id: int,
    ) -> FlaxLogitsProcessorList:
        """
        Return a :obj:`~transformers.FlaxLogitsProcessorList` with the
        :obj:`~transformers.FlaxLogitsProcessor` instances used to modify the scores of the language
        model head (min length, forced BOS, forced EOS). Arguments left as :obj:`None` fall back to
        ``self.config.mbart_config``.
        """
        processors = FlaxLogitsProcessorList()

        # init processor parameters
        # NOTE: `no_repeat_ngram_size` is resolved but no n-gram processor is constructed below.
        no_repeat_ngram_size = (
            no_repeat_ngram_size
            if no_repeat_ngram_size is not None
            else self.config.mbart_config.no_repeat_ngram_size
        )
        min_length = min_length if min_length is not None else self.config.mbart_config.min_length
        eos_token_id = (
            eos_token_id if eos_token_id is not None else self.config.mbart_config.eos_token_id
        )
        forced_bos_token_id = (
            forced_bos_token_id
            if forced_bos_token_id is not None
            else self.config.mbart_config.forced_bos_token_id
        )
        forced_eos_token_id = (
            forced_eos_token_id
            if forced_eos_token_id is not None
            else self.config.mbart_config.forced_eos_token_id
        )

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        if min_length is not None and eos_token_id is not None and min_length > -1:
            processors.append(FlaxMinLengthLogitsProcessor(min_length, eos_token_id))
        if forced_bos_token_id is not None:
            processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id))
        if forced_eos_token_id is not None:
            processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
        return processors

    def _greedy_search(
        self,
        input_ids: jax_xla.DeviceArray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
        model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
    ):
        """
        Greedy decoding: at each step append the argmax token of the processed logits.

        ``input_ids`` has shape ``(batch_size, cur_len)``. Once a row emits EOS it is marked finished
        and its remaining positions are filled with ``pad_token_id``. Returns a
        :class:`FlaxGreedySearchOutput` whose ``sequences`` have shape ``(batch_size, max_length)``.
        """
        # init values
        max_length = max_length if max_length is not None else self.config.mbart_config.max_length
        pad_token_id = (
            pad_token_id if pad_token_id is not None else self.config.mbart_config.pad_token_id
        )
        eos_token_id = (
            eos_token_id if eos_token_id is not None else self.config.mbart_config.eos_token_id
        )
        # An empty processor list leaves the logits unchanged.
        logits_processor = (
            logits_processor if logits_processor is not None else FlaxLogitsProcessorList()
        )
        model_kwargs = model_kwargs if model_kwargs is not None else {}

        batch_size, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)

        # initialize state
        state = GreedyState(
            cur_len=cur_len,
            sequences=sequences,
            running_token=input_ids,
            is_sent_finished=is_sent_finished,
            model_kwargs=model_kwargs,
        )

        def greedy_search_cond_fn(state):
            """Continue while max_length is not reached and some sequence is unfinished."""
            has_reached_max_length = state.cur_len == max_length
            all_sequence_finished = jnp.all(state.is_sent_finished)
            finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
            return ~finish_generation

        def greedy_search_body_fn(state):
            """Run one decoder step and append the argmax token."""
            model_outputs = model(state.running_token, params=params, **state.model_kwargs)
            logits = model_outputs.logits[:, -1]

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)

            next_token = jnp.argmax(logits, axis=-1)

            next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
            # finished rows emit the pad token instead of the model's prediction
            next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(
                state.sequences, next_token, (0, state.cur_len)
            )
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
            return GreedyState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                running_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:
            state = greedy_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(greedy_search_cond_fn, greedy_search_body_fn, state)
        else:
            state = lax.while_loop(greedy_search_cond_fn, greedy_search_body_fn, state)

        return FlaxGreedySearchOutput(sequences=state.sequences)

    def _sample(
        self,
        input_ids: jax_xla.DeviceArray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        prng_key: Optional[jax_xla.DeviceArray] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        logits_warper: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
        model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
    ):
        """
        Multinomial sampling: at each step sample the next token from the processed and warped logits.

        ``input_ids`` has shape ``(batch_size, cur_len)``. Once a row emits EOS it is marked finished
        and its remaining positions are filled with ``pad_token_id``. Returns a
        :class:`FlaxSampleOutput` whose ``sequences`` have shape ``(batch_size, max_length)``.
        """
        # init values
        max_length = max_length if max_length is not None else self.config.mbart_config.max_length
        pad_token_id = (
            pad_token_id if pad_token_id is not None else self.config.mbart_config.pad_token_id
        )
        eos_token_id = (
            eos_token_id if eos_token_id is not None else self.config.mbart_config.eos_token_id
        )
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
        # Empty lists leave the logits unchanged.
        logits_processor = (
            logits_processor if logits_processor is not None else FlaxLogitsProcessorList()
        )
        logits_warper = logits_warper if logits_warper is not None else FlaxLogitsProcessorList()
        model_kwargs = model_kwargs if model_kwargs is not None else {}

        batch_size, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)

        # initialize state
        state = SampleState(
            cur_len=cur_len,
            sequences=sequences,
            running_token=input_ids,
            is_sent_finished=is_sent_finished,
            prng_key=prng_key,
            model_kwargs=model_kwargs,
        )

        def sample_search_cond_fn(state):
            """Continue while max_length is not reached and some sequence is unfinished."""
            has_reached_max_length = state.cur_len == max_length
            all_sequence_finished = jnp.all(state.is_sent_finished)
            finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
            return ~finish_generation

        def sample_search_body_fn(state):
            """Run one decoder step and append a sampled token."""
            prng_key, prng_key_next = jax.random.split(state.prng_key)
            model_outputs = model(state.running_token, params=params, **state.model_kwargs)

            logits = model_outputs.logits[:, -1]

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)
            # apply top_k, top_p, temperature
            logits = logits_warper(state.sequences, logits, state.cur_len)

            # Sample from the *processed* logits: sampling from the raw model logits would silently
            # ignore every processor and warper applied above.
            next_token = jax.random.categorical(prng_key, logits, axis=-1)

            next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
            # finished rows emit the pad token instead of the sampled one
            next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(
                state.sequences, next_token, (0, state.cur_len)
            )
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return SampleState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                running_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
                prng_key=prng_key_next,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:
            state = sample_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
        else:
            state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)

        return FlaxSampleOutput(sequences=state.sequences)

    def _beam_search(
        self,
        input_ids: jax_xla.DeviceArray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
        model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
    ):
        """
        Beam search over ``input_ids`` of shape ``(batch_size, num_beams, cur_len)``.

        Returns a :class:`FlaxBeamSearchOutput` with the last (highest-ranked) beam per batch item.

        This beam search function is heavily inspired by Flax's official example:
        https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
        """

        def flatten_beam_dim(tensor):
            """Flattens the first two dimensions of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

        def unflatten_beam_dim(tensor, batch_size, num_beams):
            """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

        def gather_beams(nested, beam_indices, batch_size, new_num_beams):
            """
            Gathers the beam slices indexed by beam_indices into new beam array.
            """
            batch_indices = jnp.reshape(
                jnp.arange(batch_size * new_num_beams) // new_num_beams,
                (batch_size, new_num_beams),
            )

            def gather_fn(tensor):
                # ignore scalars (e.g. cache index)
                if tensor.ndim == 0:
                    return tensor
                return tensor[batch_indices, beam_indices]

            return jax.tree_map(gather_fn, nested)

        # init values
        max_length = max_length if max_length is not None else self.config.mbart_config.max_length
        pad_token_id = (
            pad_token_id if pad_token_id is not None else self.config.mbart_config.pad_token_id
        )
        eos_token_id = (
            eos_token_id if eos_token_id is not None else self.config.mbart_config.eos_token_id
        )
        length_penalty = (
            length_penalty
            if length_penalty is not None
            else self.config.mbart_config.length_penalty
        )
        early_stopping = (
            early_stopping
            if early_stopping is not None
            else self.config.mbart_config.early_stopping
        )
        # An empty processor list leaves the log probs unchanged.
        logits_processor = (
            logits_processor if logits_processor is not None else FlaxLogitsProcessorList()
        )
        model_kwargs = model_kwargs if model_kwargs is not None else {}

        batch_size, num_beams, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch,beam-item holding current token in loop.
        sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = jnp.full(
            (batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32
        )
        running_sequences = lax.dynamic_update_slice(running_sequences, input_ids, (0, 0, 0))

        # per batch,beam-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

        # per batch,beam-item score, logprobs; only beam 0 starts "alive" so the first step does not
        # pick num_beams copies of the same continuation.
        running_scores = jnp.tile(
            jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1]
        )
        scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # flatten beam dim
        if "encoder_outputs" in model_kwargs:
            model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
                model_kwargs["encoder_outputs"]["last_hidden_state"]
            )
        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(
            flatten_beam_dim(input_ids), max_length, **model_kwargs
        )

        # initialize state
        state = BeamSearchState(
            cur_len=cur_len,
            running_sequences=running_sequences,
            running_scores=running_scores,
            sequences=sequences,
            scores=scores,
            is_sent_finished=is_sent_finished,
            model_kwargs=model_kwargs,
        )

        def beam_search_cond_fn(state):
            """beam search state termination condition fn."""

            # 1. is less than max length?
            not_max_length_yet = state.cur_len < max_length

            # 2. can the new beams still improve? (beams are kept in ascending order, best is last)
            best_running_score = state.running_scores[:, -1:] / (max_length ** length_penalty)
            worst_finished_score = jnp.where(
                state.is_sent_finished,
                jnp.min(state.scores, axis=1, keepdims=True),
                np.array(-1.0e7),
            )
            improvement_still_possible = jnp.all(worst_finished_score < best_running_score)

            # 3. is there still a beam that has not finished?
            still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)

            return not_max_length_yet & still_open_beam & improvement_still_possible

        def beam_search_body_fn(state):
            """beam search state update fn."""
            # 1. Forward current tokens
            # Collect the current position slice along length to feed the fast
            # autoregressive decoder model. Flatten the beam dimension into batch
            # dimension for feeding into the model, then unflatten the logits and
            # the attention cache arrays again.
            input_token = flatten_beam_dim(
                lax.dynamic_slice(
                    state.running_sequences,
                    (0, 0, state.cur_len - 1),
                    (batch_size, num_beams, 1),
                )
            )
            model_outputs = model(input_token, params=params, **state.model_kwargs)

            logits = unflatten_beam_dim(model_outputs.logits[:, 0], batch_size, num_beams)
            cache = jax.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams),
                model_outputs.past_key_values,
            )

            # 2. Compute log probs
            # get log probabilities from logits,
            # process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            # Processors must see the *current* running sequences carried in `state`, not the
            # initial ones captured by this closure.
            log_probs = logits_processor(
                flatten_beam_dim(state.running_sequences),
                flatten_beam_dim(log_probs),
                state.cur_len,
            )
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Retrieve top-K
            # Each item in batch has num_beams * vocab_size candidate sequences.
            # For each item, get the top 2*k candidates with the highest log-
            # probabilities. We gather the top 2*K beams here so that even if the best
            # K sequences reach EOS simultaneously, we have another K sequences
            # remaining to continue the live beam search.
            # Recover the beam index by floor division and the token id by modulo
            # division, then write the token ids into the 2*K gathered sequences.
            beams_to_keep = 2 * num_beams
            topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
            topk_beam_indices = topk_indices // vocab_size
            topk_running_sequences = gather_beams(
                state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
            )
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            topk_sequences = lax.dynamic_update_slice(
                topk_running_sequences, topk_ids, (0, 0, state.cur_len)
            )

            # 4. Check which sequences have ended
            # Did any of these sequences reach an end marker?
            # To prevent these just finished sequences from being added to the current sequences
            # set of active beam search sequences, set their log probs to a very large
            # negative value.
            did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
            topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)

            # 5. Get running sequences scores for next
            # Determine the top k beam indices (from top 2*k beams) from log probs
            # and gather top k beams (from top 2*k beams). `flip` stores them in
            # ascending score order, so the best beam is last.
            next_topk_indices = jnp.flip(lax.top_k(topk_log_probs, k=num_beams)[1], axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, topk_log_probs], next_topk_indices, batch_size, num_beams
            )

            # 6. Process topk logits
            # Further process log probs:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len ** length_penalty)
            beams_in_batch_are_full = (
                jnp.broadcast_to(
                    state.is_sent_finished.all(axis=-1, keepdims=True),
                    did_topk_just_finished.shape,
                )
                & early_stopping
            )
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get scores, sequences, is sentence finished for next.
            # Combine sequences, scores, and flags along the beam dimension and compare
            # new finished sequence scores to existing finished scores and select the
            # best from the new set of beams
            merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
            merged_is_sent_finished = jnp.concatenate(
                [state.is_sent_finished, did_topk_just_finished], axis=1
            )
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished],
                topk_merged_indices,
                batch_size,
                num_beams,
            )

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(
                topk_beam_indices, next_topk_indices, batch_size, num_beams
            )
            next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
            model_outputs["past_key_values"] = jax.tree_map(
                lambda x: flatten_beam_dim(x), next_cache
            )
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )

        # The first step is run outside of `lax.while_loop` (unconditionally here, unlike the
        # greedy/sample paths), mirroring the reference implementation.
        state = beam_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
        else:
            state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)

        # Account for the edge-case where there are no finished sequences for a
        # particular batch item. If so, return running sequences for that batch item.
        any_finished = jnp.any(state.is_sent_finished, axis=1)
        sequences = jnp.where(
            any_finished[:, None, None], state.sequences, state.running_sequences
        )
        scores = jnp.where(any_finished[:, None], state.scores, state.running_scores)

        # take best beam for each batch (beams are kept in ascending score order)
        sequences = sequences[:, -1]
        scores = scores[:, -1]

        return FlaxBeamSearchOutput(sequences=sequences, scores=scores)