juancopi81's picture
Add t5x and mt3 models
b100e1c
raw
history blame
49.7 kB
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast decoding routines for inference from a trained model."""
import functools
from typing import Any, Callable, Mapping, Optional, Tuple, Union
import flax
from flax import traverse_util
import jax
from jax import lax
from jax import random
import jax.numpy as jnp
import numpy as np
# Type of the tree-definition objects produced by JAX pytree flattening.
# Uses the stable `jax.tree_util` namespace: the top-level `jax.tree_structure`
# alias was deprecated and has been removed from recent JAX releases.
PyTreeDef = type(jax.tree_util.tree_structure(None))
# State threaded through the temperature-sampling `lax.while_loop`:
# (loop index, sequences, cache, current token, ended flags, PRNG key,
#  accumulated log probabilities).
SamplingLoopState = Tuple[int, jnp.ndarray, Mapping[str, jnp.ndarray],
                          jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
# Constants
# "Effective negative infinity" constant for masking in beam search.
NEG_INF = np.array(-1.0e7)
# Temperatures lower than this are considered 0.0, which is handled specially
# with a conditional. This is to avoid numeric issues from exponentiating on
# 1.0/temperature when temperature is close to 0.0.
MIN_TEMPERATURE = np.array(1e-4)
#------------------------------------------------------------------------------
# Temperature Sampling
#------------------------------------------------------------------------------
# Batched `lax.dynamic_update_slice_in_dim`: the operand, the update and the
# start index are all mapped over their leading (batch) axis, while the final
# `axis` argument is shared by every batch element (`None` in `in_axes`).
_dynamic_update_vector_slice_in_dim = jax.vmap(
    lax.dynamic_update_slice_in_dim,
    in_axes=(0, 0, 0, None),
)
def _is_tracer(value: Any):
  """Reports whether `value` is an abstract JAX tracer.

  Tracers appear while a function is being traced (e.g. under `jax.jit`).
  Their concrete values are unknown, so Python-level checks such as the
  `topp`/`max_decode_steps` validations in this module must be skipped.

  Args:
    value: any Python object.

  Returns:
    True iff `value` is an instance of `jax.core.Tracer`.
  """
  tracer_cls = jax.core.Tracer
  return isinstance(value, tracer_cls)
def temperature_sample(
    inputs: jnp.ndarray,
    cache: Mapping[str, jnp.ndarray],
    tokens_to_logits: Callable[[jnp.ndarray, Mapping[str, jnp.ndarray]],
                               Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]],
    eos_id: int,
    decode_rng: Optional[jnp.ndarray] = None,
    num_decodes: int = 1,
    temperature: Union[float, jnp.ndarray] = 1.0,
    topk: int = 1,
    topp: float = 0.0,
    cache_offset: int = 0,
    initial_index: Optional[jnp.ndarray] = None,
    max_decode_steps: Optional[Union[int, jnp.ndarray]] = None,
    max_decode_steps_hard_limit: Optional[int] = None,
    rescale_log_probs: bool = True,
    state_callback_fn: Optional[Callable[[SamplingLoopState],
                                         SamplingLoopState]] = None,
    logit_callback_fn: Optional[Callable[[jnp.ndarray, SamplingLoopState],
                                         jnp.ndarray]] = None
) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Temperature sampling for language model generation.

  The temperature sampling is performed `num_decodes` times in a vectorized
  manner by expanding the batch dimension. This is similar to how beam search
  expands the batch dimension to process each batch element with multiple
  beams.

  This function dynamically updates the `inputs` array by sampling from the
  model logits, which are provided by the `tokens_to_logits` callable. The
  input sequences are expanded at the end, populated and sliced by dropping
  the first position (and a trailing scratch position).

  If `inputs` has non-zero entries, those values are not modified, i.e., the
  sampled values for those positions are discarded. This simulates teacher
  forcing on the prefix positions.

  There are a few important observations related to this function.

  1. The `inputs` is assumed to be a non-packed sequence.

  2. If `initial_index=None`, then `inputs`[:, 0] is ignored. We will use 0 as
     a BOS token to start the generation. This inherently assumes that
     `inputs` is already shifted to the right by one position. If
     `initial_index=an_array`, the token values at `inputs`[:, initial_index]
     are used as the token to start the generation.

  3. The loop index, i, is a vector of shape [batch_size]. When beginning
     generation from scratch, each value will always have the same value.
     When beginning with a partially filled cache, the loop index of different
     elements can differ, via providing a value for `initial_index`.

  4. Unless all batch elements generated the eos_id before reaching the end,
     we always make `max_decode_len = inputs.shape[1]` number of calls to
     `tokens_to_logits` when decoding from scratch and
     `max_decode_len - jnp.min(initial_index)` number of calls when starting
     from a partially filled cache.

  5. Let `output` be the output sequences, i.e.,`sequences`[:, 1:]. Then
     `output`[:, j] are the tokens generated when the while loop counter
     `i = j`. Therefore, we generate the last token when
     `i = max_decode_len - 1` and exit the while loop as all `i`s are
     incremented to `max_decode_len`.

  6. Once `eos_id` (1 in the examples below) is generated, the subsequent
     predictions are all replaced by padding token 0.

  7. When using a partially filled cache, different batch elements can have
     different lengths. This means an input that has a longer input will have
     fewer steps until its `i` value reaches `max_decode_len` than an input
     with a shorter input. We keep these longer examples alive, doing busy
     work continually overwriting a new garbage token at the end of the
     sequence until shorter examples finish.

  8. When using a partially filled cache, providing a value for
     `initial_index`, the attention cache index should be a vector of
     [batch_size].

  We show four examples to illustrate how this function works. In addition to
  input and output of the function, we also show two intermediate values:
  `expanded_prompt_inputs` and `final_sequences`. Also for simplicity, the
  examples are limited to `num_decodes = 1` usage and the `num_decodes`
  dimension is omitted. `expanded_prompt_inputs` always carries two extra
  trailing positions; the last one is a scratch (garbage) slot.

  ```
  Example 1:
    inputs = [0, 5, 6, 1, 0]
    expanded_prompt_inputs = [0, 5, 6, 1, 0, 0, 0]
    final_sequences = [0, 5, 6, 1, a, b, 0]  # before slicing.
    output = [5, 6, 1, a, b]
    where `a` is prediction while taking 1 as input and `b` is prediction while
    taking `a` as input.

  Example 2 (early stopping):
    inputs = [[0, 5, 1, 0, 0, 0, 0],
              [0, 8, 0, 0, 0, 0, 0]]
    expanded_prompt_inputs = [[0, 5, 1, 0, 0, 0, 0, 0, 0],
                              [0, 8, 0, 0, 0, 0, 0, 0, 0]]
    final_sequences = [[0, 5, 1, a, b, c=1, 0, 0, 0],
                       [0, 8, d, e, f=1, g=0, 0, 0, 0]]
    output = [[5, 1, a, b, c=1, 0, 0],
              [8, d, e, f=1, g=0, 0, 0]]

    In this example, there are two sequences. Let's look at sequence 0. The
    first generated token is `a`, which is in turn used to generate `b`.
    Finally, `c = 1` is generated with the input `b`. Then the loop terminates
    early because 1 is the `eos_id`.

    Now consider sequence 1. When `f = 1` is generated, it is considered
    done. Since sequence 0 is not done at this point, the next prediction,
    i.e., `g`, is zeroed out. This continues until the end.

  Example 3 (prefilled cache):
    inputs = [[0, 5, 2, 6, 1, 0],
              [0, 8, 1, 0, 0, 0]]
    expanded_prompt_inputs = [[0, 5, 2, 6, 1, 0, 0, 0],
                              [0, 8, 1, 0, 0, 0, 0, 0]]
    max_decode_length = 6
    i = [4, 2]
    input_tokens = [[1],
                    [1]]
    output_tokens = [[a],
                     [b]]
    expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, 0, 0],
                              [0, 8, 1, b, 0, 0, 0, 0]]
    i = [5, 3]
    input_tokens = [[a],
                    [b]]
    output_tokens = [[c],
                     [d]]
    expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, 0],
                              [0, 8, 1, b, d, 0, 0, 0]]
    i = [6, 4]
    input_tokens = [[c],
                    [d]]
    output_tokens = [[y],
                     [e]]
    expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, y],
                              [0, 8, 1, b, d, e, 0, 0]]
    i = [6, 5]
    input_tokens = [[y],
                    [e]]
    output_tokens = [[z],
                     [f]]
    expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, z],
                              [0, 8, 1, b, d, e, f, 0]]
    i = [6, 6]
    exit
    outputs = [[5, 2, 6, 1, a, c],
               [8, 1, b, d, e, f]]

    In this example, there are two sequences with different input lengths.
    Thus the two caches had been filled to different positions. As we decode,
    the first sequence hits the max decode length before the second. In order
    to avoid prematurely ending decoding for the second sequence, the first
    sequence continually overwrites the final (garbage) token (`y`, `z`),
    which is dropped from the output.

  Example 4 (prefilled cache and max decode steps):
    inputs = [[0, 2, 0, 0, 0, 0, 0, 0],
              [0, 3, 4, 0, 0, 0, 0, 0]]
    expanded_prompt_inputs = [[0, 2, 0, 0, 0, 0, 0, 0, 0, 0],
                              [0, 3, 4, 0, 0, 0, 0, 0, 0, 0]]
    initial_indices = [1, 2]
    max_decode_steps = 2

    Then `max_decode_len = [3, 4]`.
    i = [1, 2]
    input_tokens = [[2],
                    [4]]
    output_tokens = [[a],
                     [b]]
    expanded_prompt_inputs = [[0, 2, a, 0, 0, 0, 0, 0, 0, 0],
                              [0, 3, 4, b, 0, 0, 0, 0, 0, 0]]
    i = [2, 3]
    input_tokens = [[a],
                    [b]]
    output_tokens = [[c],
                     [d]]
    expanded_prompt_inputs = [[0, 2, a, c, 0, 0, 0, 0, 0, 0],
                              [0, 3, 4, b, d, 0, 0, 0, 0, 0]]
    This is the last while loop iteration with i == max_decode_len - 1.
    outputs = [[2, a, c, 0, 0, 0, 0, 0],
               [3, 4, b, d, 0, 0, 0, 0]]
  ```

  Args:
    inputs: array: [batch_size, max_decode_len] int32 sequence of tokens.
    cache: flax attention cache.
    tokens_to_logits: fast autoregressive decoder function taking single token
      slices and cache and returning next-token logits and updated cache.
    eos_id: int: end-of-sentence token for target vocabulary.
    decode_rng: JAX PRNGKey. Defaults to `jax.random.PRNGKey(0)` when None.
    num_decodes: number of decoded sequences to be returned.
    temperature: float: sampling temperature factor. As it approaches zero
      this becomes equivalent to greedy sampling.
    topk: integer: if nonzero only use the top-k logits to sample next token,
      if zero don't use any cutoff and sample from full logits over
      vocabulary.
    topp: float: if nonzero only use the smallest number of logits whose
      cumulative sum of probs adds up to (at least) topp. Will raise
      ValueError if it's nonzero when topk is nonzero; since `topk` defaults
      to 1, pass `topk=0` when using `topp`.
    cache_offset: axis offset for cache, arising from scanned layers.
    initial_index: Optional[array]: [batch_size] int32 a vector of loop
      indexes to start decoding at.
    max_decode_steps: int: an optional maximum number of decoding steps. If
      None, it will decode until the full input shape `inputs.shape[1]` is
      filled. max_decode_steps begins counting after the prompt, so it will
      decode at most len(prompt) + max_decode_steps tokens.
    max_decode_steps_hard_limit: int: an optional fixed hard limit on
      max_decode_steps. If this is set (not None and > 0), and
      max_decode_steps is also set, then max_decode_steps will be clipped to
      this limit. The value max_decode_steps can be an ndarray, but
      max_decode_steps_hard_limit must be a Python integer or None.
    rescale_log_probs: bool: whether to apply temperature, topp, and topk
      rescaling to the log probs which are returned. If True, the log_probs
      will include these transformations (for example, with topk=1, all
      log_probs will be identically 0.0). If False, the log_probs will not be
      affected, and topk/topp/temperature will not affect sequence
      probabilities.
    state_callback_fn: Function that modifies the sampling loop state before
      each step. This can be used to manipulate any part of the state either
      on the accelerator or on the host using host callback. The function
      should take a tuple of type SamplingLoopState as argument, and it
      returns the updated state. See `decoding_test.py` for an example usage.
    logit_callback_fn: Function that modifies the logits before each
      temperature sampling step. The function should take arguments (logits,
      state) and it should return the modified logits. See `decoding_test.py`
      for an example usage.

  Returns:
    A tuple (decodes, log_prob) where `decodes` holds the sampled sequences
    with shape [batch_size, num_decodes, max_decode_len] and `log_prob` holds
    the [batch_size, num_decodes] log probability of each sampled sequence.
    Both are sorted along the `num_decodes` axis by increasing `log_prob`.
  """
  if decode_rng is None:
    decode_rng = jax.random.PRNGKey(0)

  # Clip the (possibly dynamic) step budget to the static hard limit, if any.
  hard_limit_active = (max_decode_steps_hard_limit is not None and
                       max_decode_steps_hard_limit > 0)
  if hard_limit_active and max_decode_steps is not None:
    max_decode_steps = jnp.minimum(max_decode_steps,
                                   max_decode_steps_hard_limit)

  batch_size = inputs.shape[0]

  # Replicate every example `num_decodes` times along the batch axis:
  # [batch, ...] -> [batch * num_decodes, ...].
  expand = functools.partial(flat_batch_beam_expand, beam_size=num_decodes)
  flat_inputs = expand(inputs)
  flat_cache = cache_map(
      functools.partial(expand, offset=cache_offset),
      cache,
      # With a prefilled cache the cache index is a per-example vector rather
      # than a scalar that broadcasts across decodes, so it must be replicated
      # as well.
      apply_to_index=initial_index is not None)
  flat_initial_index = (None if initial_index is None else
                        expand(initial_index))

  # flat_decodes: [batch * num_decodes, len]
  # flat_log_prob: [batch * num_decodes]
  flat_decodes, flat_log_prob = _temperature_sample_single_trial(
      flat_inputs,
      flat_cache,
      tokens_to_logits,
      eos_id,
      decode_rng,
      temperature,
      topk,
      topp,
      initial_index=flat_initial_index,
      max_decode_steps=max_decode_steps,
      rescale_log_probs=rescale_log_probs,
      state_callback_fn=state_callback_fn,
      logit_callback_fn=logit_callback_fn)

  # Restore the explicit num_decodes axis.
  # [batch * num_decodes, len] -> [batch, num_decodes, len]
  decodes = unflatten_beam_dim(flat_decodes, batch_size, num_decodes)
  # [batch * num_decodes] -> [batch, num_decodes]
  log_prob = unflatten_beam_dim(flat_log_prob, batch_size, num_decodes)

  # Order the samples of each example by increasing log probability.
  order = jnp.argsort(log_prob, axis=-1)  # [batch, num_decodes]
  sorted_decodes = jnp.take_along_axis(decodes, order[..., None], axis=1)
  sorted_log_prob = jnp.take_along_axis(log_prob, order, axis=-1)
  return sorted_decodes, sorted_log_prob
def _temperature_sample_single_trial(
    inputs: jnp.ndarray,
    cache: Mapping[str, jnp.ndarray],
    tokens_to_logits: Callable[[jnp.ndarray, Mapping[str, jnp.ndarray]],
                               Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]],
    eos_id: int,
    prng_key: jnp.ndarray,
    temperature: Union[float, jnp.ndarray] = 1.0,
    topk: int = 20,
    topp: Union[float, jnp.ndarray] = 0.0,
    initial_index: Optional[jnp.ndarray] = None,
    max_decode_steps: Optional[Union[int, jnp.ndarray]] = None,
    rescale_log_probs: bool = True,
    state_callback_fn: Optional[Callable[[SamplingLoopState],
                                         SamplingLoopState]] = None,
    logit_callback_fn: Optional[Callable[[jnp.ndarray, SamplingLoopState],
                                         jnp.ndarray]] = None
) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """A helper function for `temperature_sample`.

  Runs one sampling trial for every row of `inputs`. Prompt positions
  (non-zero entries of `inputs`) are teacher-forced. Zero positions are filled
  with sampled tokens until each row emits an eos token beyond those already
  present in its prompt, or reaches its maximum length.

  Args:
    inputs: [batch, len] int token array; column 0 acts as the BOS slot.
    cache: attention cache passed to `tokens_to_logits`.
    tokens_to_logits: maps ([batch, 1] current tokens, cache) to
      ([batch, vocab] next-token logits, updated cache).
    eos_id: end-of-sequence token id.
    prng_key: PRNG key; split once per loop iteration.
    temperature: sampling temperature. Values <= MIN_TEMPERATURE select the
      greedy (argmax) branch.
    topk: if non-zero, sample only among the top-k scaled logits.
    topp: if non-zero (or traced), nucleus-sampling threshold on the
      cumulative probability of the sorted scaled logits.
    initial_index: optional [batch] int vector of per-row start positions.
    max_decode_steps: optional cap on the number of tokens generated after
      the prompt.
    rescale_log_probs: if True, returned log-probs are computed from the
      temperature/top-k/top-p adjusted logits; otherwise from raw logits.
    state_callback_fn: optional hook applied to the loop state at the start
      of every iteration.
    logit_callback_fn: optional hook `(logits, state) -> logits` applied
      before sampling.

  Returns:
    A tuple `(sequences, log_prob)`. `sequences` is [batch, len] with the BOS
    column and the trailing garbage column dropped. `log_prob` is [batch], the
    sum of log-probabilities of the sampled (non-prompt, not-yet-ended)
    tokens.

  Raises:
    ValueError: if both `topp` and `topk` are non-zero (checked only when
      `topp` is not traced), or if a non-traced `max_decode_steps` exceeds
      `inputs.shape[1]`.
  """
  # We can check the values of topp and topk only if they are not dynamic.
  if not _is_tracer(topp) and topp and topk:
    raise ValueError('At most one of `topp` or `topk` may be non-zero.')
  batch_size, max_decode_len = inputs.shape

  if max_decode_steps is not None:
    # We can check the max_decode_steps bounds only if it is not dynamic.
    if not _is_tracer(max_decode_steps) and max_decode_steps > inputs.shape[1]:
      raise ValueError('Cannot decode more steps than the sequence length.')

    # The number of decode steps required to process the prefix is the number
    # of non-zero tokens, since inputs[:, 0] == 0 is the BOS token.
    # `max_decode_len[j]` is the number of non-padding tokens in the jth element
    # of the returned sequences capped at `len(inputs)`, assuming that the
    # early stop doesn't occur. This is true with or without
    # `max_decode_steps`.
    # When the while loop index `i` for the `j`th element is
    # `i[j] = max_decode_len[j] - 1`, the generated token populates
    # sequences[i[j] + 1]. Since sequences[:, 0] is the BOS token, the
    # generated token is the `max_decode_len[j]`th non-padding token and hence
    # the `j`th element is ended.
    max_decode_len = jnp.sum(inputs != 0, axis=1) + max_decode_steps
    max_decode_len = jnp.minimum(inputs.shape[1], max_decode_len)

  # In the case of starting generation from a non-zero index, it is possible for
  # one batch element to reach `max_decode_len` number of decoding steps before
  # another. In order to let the last element decode all the way to
  # `max_decode_len` number of steps, we add a final garbage token to the end of
  # the sequences. Any element that has reached `max_decode_len` before the rest
  # of the elements will continually overwrite this token until all elements
  # finish.
  # [batch, length] -> [batch, length+2]
  expanded_prompt_inputs = jnp.append(
      inputs, jnp.zeros((batch_size, 2), dtype=inputs.dtype), axis=1)
  end_marker = jnp.array(eos_id)

  temperature = jnp.asarray(temperature)

  # Initialize sampling loop state.
  # initial loop PRNGKey
  rng0 = prng_key
  if initial_index is None:
    # the per batch-item loop position counter.
    i0 = jnp.zeros((batch_size), dtype=jnp.int32)
    # the per batch-item holding current token in loop (BOS = 0).
    token0 = jnp.zeros((batch_size, 1), dtype=jnp.int32)
  else:
    # the per batch-item loop position counter.
    i0 = initial_index
    # the per batch-item holding current token in loop.
    # Select the token that the initial index is pointing to.
    token0 = jnp.take_along_axis(
        expanded_prompt_inputs, jnp.expand_dims(i0, axis=1), axis=1)
  # per batch-item state bit indicating if sentence has finished.
  ended0 = jnp.zeros((batch_size, 1), dtype=jnp.bool_)
  # (batch, length+2) array containing prefix prompt tokens for sampling loop
  # as well as the generated output of newly sampled tokens.
  sequences0 = expanded_prompt_inputs
  log_prob0 = jnp.zeros((batch_size,), dtype=jnp.float32)
  # Sampling loop state is stored in a simple tuple.
  sampling_loop_init_state = (i0, sequences0, cache, token0, ended0, rng0,
                              log_prob0)
  # Initial eos count to be used to determine whether eos is "generated". Many
  # inputs follow the format bos, inputs..., eos, targets..., eos. By counting
  # the number of eos tokens we can detect when a new one is added, instead of
  # just finding the one that probably ends the inputs.
  # [batch, 1]
  initial_eos_count = jnp.sum(sequences0 == end_marker, axis=-1, keepdims=True)

  def sampling_loop_cond_fn(state: SamplingLoopState) -> bool:
    """Sampling loop termination condition: keep going until all rows ended."""
    (_, _, _, _, ended, _, _) = state
    # Have all sampled sequences reached an end marker?
    # Different elements in the batch can be at different loop indices, if any
    # of our examples are not at the end, keep going.
    all_sequences_ended = jnp.all(ended)
    return ~all_sequences_ended

  def sampling_loop_body_fn(state: SamplingLoopState) -> SamplingLoopState:
    """Sampling loop state update: samples and records one token per row."""

    if state_callback_fn is not None:
      state = state_callback_fn(state)

    i, sequences, cache, cur_token, ended, rng, log_prob = state
    # Split RNG for sampling.
    rng1, rng2 = random.split(rng)
    # Call fast-decoder model on current tokens to get next-position logits.
    logits, new_cache = tokens_to_logits(cur_token, cache)
    # Sample next token from logits.

    if logit_callback_fn is not None:
      logits = logit_callback_fn(logits, state)

    def sample_logits_with_nonzero_temperature(logits):
      """Samples with temperature scaling and optional top-k / top-p masking."""
      scaled_logits = logits / jnp.maximum(temperature, MIN_TEMPERATURE)
      if topk:
        # Get top-k logits and their indices, sample within these top-k tokens.
        topk_logits, _ = lax.top_k(scaled_logits, topk)
        cutoff_logit = topk_logits[:, -1, None]
        scaled_logits = jnp.where(scaled_logits < cutoff_logit,
                                  jnp.full_like(scaled_logits, NEG_INF),
                                  scaled_logits)

      # When topp is dynamic, we always use it since we cannot check
      # non-zeroness (but it will have no effect if topp is 0.0).
      if _is_tracer(topp) or topp:
        logits_sorted = jnp.sort(
            scaled_logits, axis=-1)[:, ::-1]  # sort descending
        sorted_cum_probs = jnp.cumsum(
            jax.nn.softmax(logits_sorted, axis=-1), axis=-1)
        # Number of sorted entries whose cumulative probability is still
        # below topp; the entry at that position is the smallest one kept.
        cutoff_index = jnp.sum(sorted_cum_probs < topp, axis=-1, keepdims=True)
        cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1)
        scaled_logits = jnp.where(scaled_logits < cutoff_logit,
                                  jnp.full_like(scaled_logits, NEG_INF),
                                  scaled_logits)

      # [batch]
      next_token = random.categorical(rng1, scaled_logits).astype(jnp.int32)

      # log probability of the current token conditioned on the previously
      # sampled and prefix tokens.
      # [batch, vocab] -> [batch, vocab]
      if rescale_log_probs:
        log_probs = jax.nn.log_softmax(scaled_logits)
      else:
        log_probs = jax.nn.log_softmax(logits)
      # [batch, vocab] -> [batch]
      next_log_prob = jnp.squeeze(
          jnp.take_along_axis(
              log_probs, jnp.expand_dims(next_token, axis=1), axis=-1),
          axis=-1)

      return (next_token, next_log_prob)

    def sample_logits_with_zero_temperature(logits):
      """Greedy (argmax) selection used when temperature <= MIN_TEMPERATURE."""
      # For zero temperature, we always want the greedy output, regardless
      # of the values of topk and topp.

      next_token = jnp.argmax(logits, -1).astype(jnp.int32)

      if rescale_log_probs:
        # Greedy selection is deterministic, so its rescaled log-prob is 0.
        next_log_prob = jnp.zeros_like(next_token, dtype=jnp.float32)
      else:
        log_probs = jax.nn.log_softmax(logits)
        next_log_prob = jnp.squeeze(
            jnp.take_along_axis(
                log_probs, jnp.expand_dims(next_token, axis=1), axis=-1),
            axis=-1)

      return (next_token, next_log_prob)

    # Perform sampling with temperature
    (next_token,
     next_log_prob) = lax.cond(temperature > MIN_TEMPERATURE,
                               sample_logits_with_nonzero_temperature,
                               sample_logits_with_zero_temperature, logits)

    # When different batch elements are at different points in the loop counter,
    # it is possible that an element that started at a higher index will reach
    # `max_decode_len` before other elements. When this happens we need to make
    # sure this element continually overwrites our new garbage collection index.
    # Here we clamp `i` to `max_decode_len`. This will cause a write to
    # `max_decode_len + 1` which is the final index in `sequences`. Subsequent
    # loop body executions will also get their value clamped causing continual
    # overwriting of the final garbage position until all examples are finished.
    i = jnp.minimum(i, max_decode_len)

    # Only use sampled tokens if we're past provided prefix tokens.
    # Select the next token from sequences.
    # [batch]
    next_input_token = jnp.squeeze(
        jnp.take_along_axis(sequences, jnp.expand_dims(i + 1, axis=1), axis=1),
        axis=1)

    # Check if the next token is padding (a target) or non-padding (an input).
    # Mask will have `1` for targets and `0` for inputs.
    out_of_prompt = (next_input_token == 0)
    # Select the sampled next token for targets and the actual next token for
    # inputs (teacher forcing).
    # [batch]
    next_token = (
        next_token * out_of_prompt + next_input_token * ~out_of_prompt)

    # only add probability if outside prefix region and the row has not ended
    # [batch] -> [batch]
    next_log_prob = log_prob + (next_log_prob * out_of_prompt) * jnp.squeeze(
        ~ended, axis=-1).astype(jnp.int32)

    # [batch] -> [batch, 1]
    next_token = jnp.expand_dims(next_token, axis=-1)

    # If end-marker reached for batch item, only emit padding tokens.
    # [batch, 1] * [batch, 1] -> [batch, 1]
    next_token_or_endpad = next_token * ~ended
    # Add current sampled tokens to recorded sequences by blending a one-hot
    # mask over column i + 1 of each row.
    # NOTE(review): an earlier variant used
    # `_dynamic_update_vector_slice_in_dim` for this write; the reason for
    # preferring the one-hot blend is not stated here — confirm before
    # changing.
    one_hot = jax.nn.one_hot(i + 1, sequences.shape[1], dtype=sequences.dtype)
    new_sequences = sequences * (1 - one_hot) + next_token_or_endpad * one_hot
    # Count eos tokens in the sequences and compare to the initial count
    # [batch, 1]
    cur_eos_count = jnp.sum(new_sequences == end_marker, axis=-1, keepdims=True)

    # Have we reached max decoding length?
    # We generally index into sequences[:, i + 1], and sequences.shape[1] =
    # max_decode_len + 2, therefore i == max_decode_len - 1 will write to
    # sequences[-2] which is our last valid location. i == max_decode_len will
    # write to sequences[-1] which is our garbage collection token. Thus `i`
    # should be strictly less than max_decode_len.
    # [batch, 1]
    has_additional_eos = cur_eos_count > initial_eos_count
    ended |= has_additional_eos | jnp.expand_dims(
        i >= max_decode_len - 1, axis=1)

    return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended, rng2,
            next_log_prob)

  # Run sampling loop and collect final state.
  final_state = lax.while_loop(sampling_loop_cond_fn, sampling_loop_body_fn,
                               sampling_loop_init_state)

  # Pick part of the state corresponding to the sampled sequences.
  final_sequences = final_state[1]
  log_prob = final_state[-1]
  # Drop the first position because they are dummy bos tokens. Drop the new
  # garbage collection token at the end too.
  return final_sequences[:, 1:-1], log_prob
#------------------------------------------------------------------------------
# Beam Search
#------------------------------------------------------------------------------
def brevity_penalty(alpha: float, length: int) -> jnp.ndarray:
  """Length-normalization term used by beam search to penalize short outputs.

  Computes ((5 + length) / 6) ** alpha.

  Args:
    alpha: float: brevity-penalty scaling parameter (0.0 disables it).
    length: int: length of the sequence being scored.

  Returns:
    The penalty as a JAX scalar.
  """
  base = (length + 5.0) / 6.0
  return jnp.power(base, alpha)
# Beam handling utility functions:
def cache_map(fn, cache, apply_to_index: bool = False):
  """Maps a function over the entries of a (possibly multi-layer) cache.

  The cache is flattened, `fn` is applied to every selected leaf, and the
  original nesting is restored. If `cache` is a `flax.core.FrozenDict`, the
  result is frozen as well.

  Entries whose leaf name is `cached_bias` (cached relative position bias) or
  `position_embedder_index` (scalar index of the absolute position embedder)
  are never mapped. `cache_index` entries are mapped only when
  `apply_to_index` is True.

  Args:
    fn: The function to apply to each selected cache entry.
    cache: The cache (nested dict or FrozenDict) to apply it to.
    apply_to_index: Whether to apply the function to the cache index.

  Returns:
    The result of applying `fn` to the cache, with the same structure (and
    frozenness) as `cache`.
  """
  frozen = isinstance(cache, flax.core.FrozenDict)
  if frozen:
    cache = flax.core.unfreeze(cache)
  flat_cache = traverse_util.flatten_dict(cache)
  # Leaf names excluded from mapping (e.g. from beam expansion).
  # TODO(levskaya): generalize cache_map to accept a list of leaf names to
  # map over, instead of doing this ad-hoc.
  excluded_leaf_names = {'cached_bias', 'position_embedder_index'}
  if not apply_to_index:
    excluded_leaf_names.add('cache_index')
  keyvals = {
      k: v for k, v in flat_cache.items() if k[-1] not in excluded_leaf_names
  }
  # `jax.tree_util.tree_map` is the stable API; the top-level `jax.tree_map`
  # alias is deprecated and removed in recent JAX releases.
  keyvals = jax.tree_util.tree_map(fn, keyvals)
  flat_cache.update(keyvals)
  new_cache = traverse_util.unflatten_dict(flat_cache)
  if frozen:
    new_cache = flax.core.freeze(new_cache)
  return new_cache
def add_beam_dim(x: jnp.ndarray,
                 beam_size: int,
                 offset: int = 0) -> jnp.ndarray:
  """Inserts a beam axis after axis `offset` and replicates `x` into it.

  Args:
    x: non-scalar array, e.g. [batch, ...] when `offset == 0`.
    beam_size: number of copies along the new axis.
    offset: axis offset (e.g. arising from scanned layers) of the batch axis.

  Returns:
    Array with a new axis of size `beam_size` at position `offset + 1`; every
    slice along that axis equals `x`.
  """
  beam_axis = offset + 1
  expanded = jnp.expand_dims(x, axis=beam_axis)
  reps = tuple(beam_size if ax == beam_axis else 1
               for ax in range(expanded.ndim))
  return jnp.tile(expanded, reps)
def flatten_beam_dim(x: jnp.ndarray, offset: int = 0) -> jnp.ndarray:
  """Merges axes `offset` and `offset + 1` of a non-scalar array into one.

  With the default `offset=0` this turns [batch, beam, ...] into
  [batch * beam, ...].
  """
  shape = list(x.shape)
  leading = shape.pop(offset)
  shape[offset] = leading * shape[offset]
  return x.reshape(shape)
def unflatten_beam_dim(x: jnp.ndarray,
                       batch_size: int,
                       beam_size: int,
                       offset: int = 0) -> jnp.ndarray:
  """Splits the flat batch*beam axis `offset` into [batch_size, beam_size].

  The inverse of `flatten_beam_dim`: [..., batch * beam, ...] becomes
  [..., batch, beam, ...]. Asserts that the axis length equals
  `batch_size * beam_size`.
  """
  assert batch_size * beam_size == x.shape[offset]
  dims = tuple(x.shape)
  split = (batch_size, beam_size)
  return x.reshape(dims[:offset] + split + dims[offset + 1:])
def flat_batch_beam_expand(x: jnp.ndarray,
                           beam_size: int,
                           offset: int = 0) -> jnp.ndarray:
  """Repeats each batch item `beam_size` times along the batch axis.

  [batch, ...] -> [batch * beam_size, ...] (after `offset` leading axes),
  with the copies of each item placed next to each other.
  """
  with_beam_axis = add_beam_dim(x, beam_size, offset)
  return flatten_beam_dim(with_beam_axis, offset)
def cache_gather_beams(nested: PyTreeDef,
                       beam_indices: jnp.ndarray,
                       batch_size: int,
                       old_beam_size: int,
                       new_beam_size: int,
                       one_hot: bool = True,
                       offset: int = 0) -> jnp.ndarray:
  """Gathers the cache beam slices indexed by beam_indices into new beam array.

  Args:
    nested: cache pytree.
    beam_indices: [batch_size, new_beam_size] array of beam indices to select.
    batch_size: size of batch.
    old_beam_size: size of _old_ beam dimension.
    new_beam_size: size of _new_ beam dimension.
    one_hot: whether to perform gathers by one-hot contraction or directly.
    offset: cache axis offset from scanned layers; only 0 or 1 is supported.

  Returns:
    New pytree with new beam arrays.
    [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
    (preceded by one leading axis when `offset == 1`).
  """
  assert offset in (0, 1), 'general offsets not supported'

  if one_hot:
    # Gather via one-hot contraction, needed for SPMD partitioning.
    selector = jax.nn.one_hot(beam_indices, old_beam_size, dtype=jnp.int32)
    spec = 'beo,bo...->be...' if offset == 0 else 'beo,lbo...->lbe...'

    def gather_fn(x):
      return jnp.einsum(spec, selector, x).astype(x.dtype)
  else:
    # True gather via fancy indexing.
    batch_indices = jnp.reshape(
        jnp.arange(batch_size * new_beam_size) // new_beam_size,
        (batch_size, new_beam_size))
    if offset == 0:
      def gather_fn(x):
        return x[batch_indices, beam_indices]
    else:
      def gather_fn(x):
        return x[:, batch_indices, beam_indices]

  return cache_map(gather_fn, nested)
def gather_beams(nested: PyTreeDef,
beam_indices: jnp.ndarray,
batch_size: int,
old_beam_size: int,
new_beam_size: int,
one_hot: bool = True) -> jnp.ndarray:
"""Gathers the beam slices indexed by beam_indices into new beam array.
Args:
nested: pytree of arrays or scalars (the latter ignored).
beam_indices: array of beam_indices
batch_size: size of batch.
old_beam_size: size of _old_ beam dimension.
new_beam_size: size of _new_ beam dimension.
one_hot: whether to perform gathers by one-hot contraction or directly.
Returns:
New pytree with new beam arrays.
[batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
"""
if one_hot:
# Gather via one-hot contraction, needed for SPMD partitioning.
oh_beam_indices = jax.nn.one_hot(
beam_indices, old_beam_size, dtype=jnp.int32)
def gather_fn(x):
return jnp.einsum('beo,bo...->be...', oh_beam_indices, x).astype(x.dtype)
return jax.tree_map(gather_fn, nested)
else:
# True gather via fancy indexing.
batch_indices = jnp.reshape(
jnp.arange(batch_size * new_beam_size) // new_beam_size,
(batch_size, new_beam_size))
def gather_fn(x):
return x[batch_indices, beam_indices]
return jax.tree_map(gather_fn, nested)
def top_k_two_stage(x, k):
  """Wrapper around lax.top_k with low-batch optimization.

  For small, statically-known batches (<= 8) with many samples
  (> 8 * 128 * k), the input is reshaped to fill 128 lanes. Top-k is then
  taken per lane, and a second top-k is taken over the per-lane winners. This
  reduces tensor padding in the TopK call on TPU. Otherwise `lax.top_k` is
  used directly.

  Args:
    x: tensor with shape f32[batch, num_samples].
    k: integer indicating how many top values to return.

  Returns:
    Largest k values and indices with shape (f32[batch, k], s32[batch, k]).
  """
  batch, num_samples = x.shape
  num_lanes = 128
  if (isinstance(batch, int) and batch <= 8 and
      num_samples > 8 * num_lanes * k):
    # At small batch, when num_samples is sufficiently large, optimize
    # execution on TPU by doing TopK in two stages. Reshaping 'x' to fill
    # lanes reduces tensor padding in TopK call.
    if num_samples % num_lanes != 0:
      # Pad input tensor to multiples of num_lanes. Padding uses -inf so
      # padded entries cannot outrank real values. `-np.inf` replaces
      # `np.NINF`, which was removed in NumPy 2.0.
      num_samples_rounded_up = num_samples + (
          num_lanes - num_samples % num_lanes)
      x = jnp.pad(
          x, ((0, 0), (0, num_samples_rounded_up - num_samples)),
          mode='constant',
          constant_values=-np.inf)
      num_samples = num_samples_rounded_up
    # Reshape input tensor to fill lanes: row (b * num_lanes + l) holds the
    # contiguous chunk x[b, l * sublanes:(l + 1) * sublanes].
    num_samples_sublanes = num_samples // num_lanes
    x_reshaped = jnp.reshape(x, (batch * num_lanes, num_samples_sublanes))

    # First stage top_k.
    vals, indices = lax.top_k(x_reshaped, k)
    indices = jnp.reshape(indices, (batch, num_lanes, k))
    # Convert per-chunk indices back to indices into the original row.
    index_offsets = jnp.reshape(num_samples_sublanes * jnp.arange(num_lanes),
                                (1, num_lanes, 1))
    indices = jnp.reshape(
        jnp.add(index_offsets, indices), (batch, num_lanes * k))
    vals = jnp.reshape(vals, (batch, num_lanes * k))

    # Second stage top_k.
    vals_s2, indices_s2 = lax.top_k(vals, k)
    indices_s2 = jnp.take_along_axis(indices, indices_s2, axis=1)
    return vals_s2, indices_s2
  else:
    # Use default TopK implementation.
    return lax.top_k(x, k)
def gather_topk_beams(nested: PyTreeDef, score_or_log_prob: jnp.ndarray,
                      batch_size: int, new_beam_size: int) -> jnp.ndarray:
  """Keeps the `new_beam_size` best beams according to `score_or_log_prob`.

  Args:
    nested: pytree of arrays with leading dims [batch_size, old_beam_size],
      forwarded to `gather_beams`.
    score_or_log_prob: [batch_size, old_beam_size] array of values to sort by
      for top-k selection of beam slices.
    batch_size: int: size of batch.
    new_beam_size: int: size of _new_ top-k selected beam dimension.

  Returns:
    New pytree with new beam arrays containing the top new_beam_size slices,
    ordered by increasing score along the beam axis.
    [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
  """
  old_beam_size = score_or_log_prob.shape[1]
  # lax.top_k yields indices in descending-score order; reverse them so the
  # selected beams come out in ascending-score order.
  descending = lax.top_k(score_or_log_prob, k=new_beam_size)[1]
  ascending = descending[:, ::-1]
  return gather_beams(nested, ascending, batch_size, old_beam_size,
                      new_beam_size)
# Beam search state:
@flax.struct.dataclass
class BeamState:
  """Holds beam search state data."""
  # Field annotations use `jnp.ndarray`: the former `jnp.DeviceArray` alias was
  # removed from JAX, which made this class body fail at import time.

  # The position of the decoding loop in the length dimension.
  cur_index: jnp.ndarray  # scalar int32: current decoded length index
  # The active sequence log probabilities and finished sequence scores.
  live_logprobs: jnp.ndarray  # float32: [batch_size, beam_size]
  finished_scores: jnp.ndarray  # float32: [batch_size, beam_size]
  # The current active-beam-searching and finished sequences.
  live_seqs: jnp.ndarray  # int32: [batch_size, beam_size, max_decode_len]
  finished_seqs: jnp.ndarray  # int32: [batch_size, beam_size,
  #                                     max_decode_len]
  # Records which of the 'finished_seqs' is occupied and not a filler slot.
  finished_flags: jnp.ndarray  # bool: [batch_size, beam_size]
  # The current state of the autoregressive decoding caches.
  cache: PyTreeDef  # Any pytree of arrays, e.g. flax attention Cache object
def beam_init(batch_size: int,
              beam_size: int,
              max_decode_len: int,
              cache: Mapping[str, jnp.ndarray],
              offset: int = 0) -> BeamState:
  """Builds the starting BeamState for beam search.

  Beam 0 of every batch item starts with log-probability 0.0 and all other
  beams start at NEG_INF. Every finished slot starts empty (score NEG_INF,
  flag False), and both sequence buffers are zero-filled int32 arrays.

  Args:
    batch_size: number of batch items.
    beam_size: number of beams kept per batch item.
    max_decode_len: length of the sequence buffers.
    cache: decoding cache pytree. It is passed through `cache_map` with
      `add_beam_dim` to give its entries a beam dimension.
    offset: axis offset forwarded to `add_beam_dim`.

  Returns:
    The initial BeamState, with `cur_index` set to 0.
  """
  scores_shape = (batch_size, beam_size)
  seqs_shape = (batch_size, beam_size, max_decode_len)

  # All beams begin with identical (zero) sequences. Masking every beam but
  # the first keeps the first expansion from producing duplicate hypotheses.
  first_beam_only = [0.0] + [NEG_INF] * (beam_size - 1)

  return BeamState(
      cur_index=jnp.array(0),
      live_logprobs=jnp.tile(jnp.array(first_beam_only), [batch_size, 1]),
      finished_scores=jnp.ones(scores_shape) * NEG_INF,
      live_seqs=jnp.zeros(seqs_shape, jnp.int32),
      finished_seqs=jnp.zeros(seqs_shape, jnp.int32),
      finished_flags=jnp.zeros(scores_shape, jnp.bool_),
      cache=cache_map(lambda leaf: add_beam_dim(leaf, beam_size, offset),
                      cache))
# Beam search routine:
def beam_search(inputs: jnp.ndarray,
                cache: Mapping[str, jnp.ndarray],
                tokens_to_logits: Callable[
                    [jnp.ndarray, Mapping[str, jnp.ndarray]],
                    Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]],
                eos_id: int,
                num_decodes: int = 4,
                alpha: float = 0.6,
                max_decode_len: Optional[int] = None,
                decode_rng: Optional[jnp.ndarray] = None,
                cache_offset: int = 0) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Beam search for transformer machine translation.

  If `inputs` has non-zero entries, those values are not modified, i.e.,
  the sampled values for those positions are discarded. This simulates the
  teacher forcing on the prefix positions.

  Args:
    inputs: array: [batch_size, length] int32 sequence of tokens.
    cache: flax attention cache.
    tokens_to_logits: fast autoregressive decoder function taking single token
      slices and cache and returning next-token logits and updated cache.
    eos_id: int: id of end-of-sentence token for target vocabulary.
    num_decodes: number of decoded sequences to be returned. This is equivalent
      to the number of beams used in the beam search.
    alpha: float: scaling factor for brevity penalty.
    max_decode_len: int: an optional maximum length of decoded sequence. If
      None, it uses `inputs.shape[1]` as `max_decode_len`.
    decode_rng: Unused decoder RNG seed.
    cache_offset: axis offset for cache, arising from scanned layers.

  Returns:
    Tuple of:
      [batch_size, beam_size, max_decode_len] top-scoring sequences
      [batch_size, beam_size] beam-search scores.

    After at least one decoding step, the beam axis is sorted by ascending
    score, so the best sequence of each batch item is at index -1. Finished
    slots that never received a sequence carry a very negative score (around
    NEG_INF). For a batch item where no beam produced `eos_id`, the live
    sequences are returned instead, scored by their raw (not
    brevity-normalized) log-probabilities.
  """
  del decode_rng
  # We liberally annotate shape information for clarity below.

  beam_size = num_decodes

  batch_size = inputs.shape[0]
  end_marker = jnp.array(eos_id)
  if max_decode_len is None:
    max_decode_len = inputs.shape[1]
  # We start with a dummy token in the beginning so extend the maximum length.
  max_decode_len += 1

  # initialize beam search state
  beam_search_init_state = beam_init(batch_size, beam_size, max_decode_len,
                                     cache, cache_offset)

  def beam_search_loop_cond_fn(state: BeamState) -> bool:
    """Beam search loop termination condition."""
    # Have we reached max decoding length?
    # Because we mutate the "i+1" position, we stop one token before the end.
    not_at_end = (state.cur_index < max_decode_len - 1)

    # Is no further progress in the beam search possible?
    # Get the best possible scores from alive sequences. Live beams are kept
    # in ascending order (see the `jnp.flip` in the loop body), so the best
    # one is in the last slot. Dividing by the penalty at max_decode_len is
    # presumably an optimistic bound, assuming brevity_penalty grows with
    # length -- TODO confirm.
    min_brevity_penalty = brevity_penalty(alpha, max_decode_len)
    best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty
    # Get the worst scores from finished sequences.
    worst_finished_scores = jnp.min(
        state.finished_scores, axis=1, keepdims=True)
    # Mask out scores from slots without any actual finished sequences.
    worst_finished_scores = jnp.where(state.finished_flags,
                                      worst_finished_scores, NEG_INF)
    # If no best possible live score is better than current worst finished
    # scores, the search cannot improve the finished set further.
    search_terminated = jnp.all(worst_finished_scores > best_live_scores)

    # If we're not at the max decode length, and the search hasn't terminated,
    # continue looping.
    return not_at_end & (~search_terminated)

  def beam_search_loop_body_fn(state: BeamState) -> BeamState:
    """Beam search loop state update function."""
    # Collect the current position slice along length to feed the fast
    # autoregressive decoder model. Flatten the beam dimension into batch
    # dimension for feeding into the model.
    # --> [batch * beam, 1]
    flat_ids = flatten_beam_dim(
        lax.dynamic_slice(state.live_seqs, (0, 0, state.cur_index),
                          (batch_size, beam_size, 1)))
    # Flatten beam dimension into batch to be compatible with model.
    # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
    flat_cache = cache_map(
        functools.partial(flatten_beam_dim, offset=cache_offset), state.cache)

    # Call fast-decoder model on current tokens to get next-position logits.
    # --> [batch * beam, vocab]
    flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache)

    # unflatten beam dimension
    # [batch * beam, vocab] --> [batch, beam, vocab]
    logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)

    # Unflatten beam dimension in attention cache arrays
    # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
    new_cache = cache_map(
        lambda x: unflatten_beam_dim(x, batch_size, beam_size, cache_offset),
        new_flat_cache)

    # Gather log probabilities from logits
    candidate_log_probs = jax.nn.log_softmax(logits)
    # Add new logprobs to existing prefix logprobs.
    # --> [batch, beam, vocab]
    log_probs = (
        candidate_log_probs + jnp.expand_dims(state.live_logprobs, axis=2))

    # We'll need the vocab size, gather it from the log probability dimension.
    vocab_size = log_probs.shape[-1]

    # Each item in batch has beam_size * vocab_size candidate sequences.
    # For each item, get the top 2*k candidates with the highest log-
    # probabilities. We gather the top 2*K beams here so that even if the best
    # K sequences reach EOS simultaneously, we have another K sequences
    # remaining to continue the live beam search.
    beams_to_keep = 2 * beam_size
    # Flatten beam and vocab dimensions.
    flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size))
    # Gather the top 2*K scores from _all_ beams.
    # --> [batch, 2*beams], [batch, 2*beams]
    topk_log_probs, topk_indices = top_k_two_stage(
        flat_log_probs, k=beams_to_keep)

    # Append the most probable 2*K token IDs to the top 2*K sequences
    # Recover token id by modulo division.
    topk_ids = topk_indices % vocab_size
    # Force decode `inputs` into topk_ids up until PAD. When `inputs` is all
    # PADs this is a no-op.
    # NOTE(review): cur_index + 1 can reach the original max_decode_len, which
    # may equal or exceed inputs.shape[1]. JAX clamps rather than raises on
    # out-of-bounds reads, so the last column of `inputs` is presumably reused
    # in that case -- confirm this is intended.
    next_input_token = jnp.expand_dims(
        inputs, axis=1).astype(jnp.int32)[:, :, state.cur_index + 1]
    out_of_prompt = (next_input_token == 0)

    # When forcing prompts, update log probabilities to `0` for the top of the
    # beam and -INF for the rest, effectively keeping only one beam alive.
    # --> [batch, 2*beams]
    inside_prompt_log_probs = jnp.concatenate([
        jnp.zeros((batch_size, 1), dtype=topk_log_probs.dtype),
        jnp.full_like(topk_log_probs[:, :beams_to_keep - 1], NEG_INF)
    ],
                                              axis=1)
    topk_log_probs = (
        topk_log_probs * out_of_prompt +
        inside_prompt_log_probs * ~out_of_prompt)

    topk_ids = topk_ids * out_of_prompt + next_input_token * ~out_of_prompt

    # Expand id array for broadcasting
    # --> [batch, 2*beams, 1]
    topk_ids = jnp.expand_dims(topk_ids, axis=2)

    # Recover the beam index by floor division.
    topk_beam_indices = topk_indices // vocab_size
    # Gather 2*k top beams.
    # --> [batch, 2*beams, length]
    topk_seq = gather_beams(state.live_seqs, topk_beam_indices, batch_size,
                            beam_size, beams_to_keep)

    # Update sequences for the 2*K top-k new sequences.
    # --> [batch, 2*beams, length]
    topk_seq = lax.dynamic_update_slice(topk_seq, topk_ids,
                                        (0, 0, state.cur_index + 1))

    # Update LIVE (in-progress) sequences:
    # Did any of these sequences reach an end marker?
    # --> [batch, 2*beams]
    newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
    # To prevent these newly finished sequences from being added to the LIVE
    # set of active beam search sequences, set their log probs to a very large
    # negative value.
    new_log_probs = topk_log_probs + newly_finished * NEG_INF
    # Determine the top k beam indices (from top 2*k beams) from log probs.
    # The flip puts the live beams in ascending order (best last), which the
    # loop condition relies on.
    # --> [batch, beams]
    _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size)
    new_topk_indices = jnp.flip(new_topk_indices, axis=1)
    # Gather the top k beams (from top 2*k beams).
    # --> [batch, beams, length], [batch, beams]
    top_alive_seq, top_alive_log_probs = gather_beams([topk_seq, new_log_probs],
                                                      new_topk_indices,
                                                      batch_size, 2 * beam_size,
                                                      beam_size)

    # Determine the top k beam indices from the original set of all beams.
    # --> [batch, beams]
    top_alive_indices = gather_beams(topk_beam_indices, new_topk_indices,
                                     batch_size, 2 * beam_size, beam_size)
    # With these, gather the top k beam-associated caches.
    # --> {[batch, beams, ...], ...}
    top_alive_cache = cache_gather_beams(new_cache, top_alive_indices,
                                         batch_size, beam_size, beam_size, True,
                                         cache_offset)

    # Update FINISHED (reached end of sentence) sequences:
    # Calculate new seq scores from log probabilities.
    new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1)
    # Mask out the still unfinished sequences by adding large negative value.
    # --> [batch, 2*beams]
    new_scores += (~newly_finished) * NEG_INF

    # Combine sequences, scores, and flags along the beam dimension and compare
    # new finished sequence scores to existing finished scores and select the
    # best from the new set of beams.
    finished_seqs = jnp.concatenate(  # --> [batch, 3*beams, length]
        [state.finished_seqs, topk_seq],
        axis=1)
    finished_scores = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_scores, new_scores], axis=1)
    finished_flags = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_flags, newly_finished], axis=1)
    # --> [batch, beams, length], [batch, beams], [batch, beams]
    top_finished_seq, top_finished_scores, top_finished_flags = (
        gather_topk_beams([finished_seqs, finished_scores, finished_flags],
                          finished_scores, batch_size, beam_size))

    return BeamState(
        cur_index=state.cur_index + 1,
        live_logprobs=top_alive_log_probs,
        finished_scores=top_finished_scores,
        live_seqs=top_alive_seq,
        finished_seqs=top_finished_seq,
        finished_flags=top_finished_flags,
        cache=top_alive_cache)

  # Run while loop and get final beam search state.
  final_state = lax.while_loop(beam_search_loop_cond_fn,
                               beam_search_loop_body_fn, beam_search_init_state)

  # Account for the edge-case where there are no finished sequences for a
  # particular batch item. If so, return live sequences for that batch item.
  # True for batch items with at least one finished sequence.
  # --> [batch]
  any_finished = jnp.any(final_state.finished_flags, axis=1)
  # --> [batch, beams, length]
  finished_seqs = jnp.where(any_finished[:, None, None],
                            final_state.finished_seqs, final_state.live_seqs)
  # --> [batch, beams]
  finished_scores = jnp.where(any_finished[:, None],
                              final_state.finished_scores,
                              final_state.live_logprobs)

  # Drop the first dummy 0 token.
  return finished_seqs[:, :, 1:], finished_scores