"""Inference."""

import functools
from typing import Any, Callable, Optional, Tuple

import flax
from flax import linen as nn
import jax
from jax import lax
from jax import numpy as jnp
import numpy as np

EOS_ID = 1
NEG_INF = np.array(-1.0e7)

GenerateFn = Callable[...,
                      Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray]]]


def temperature_sampling(*args, temperature=1.0, top_k=0, top_p=0.0, **kwargs):
  """Convenience wrapper for temperature sampling."""
  return generate(*args, generate_fn=_temperature_sampling,
                  temperature=temperature,
                  top_k=top_k,
                  top_p=top_p,
                  **kwargs)


def topk_sampling(*args, temperature=1.0, top_k=20, **kwargs):
  """Convenience wrapper for top-k sampling."""
  return generate(*args, generate_fn=_temperature_sampling,
                  temperature=temperature,
                  top_k=top_k,
                  top_p=0.0,
                  **kwargs)


def nucleus_sampling(*args, temperature=1.0, top_p=0.2, **kwargs):
  """Convenience wrapper for nucleus sampling."""
  return generate(*args, generate_fn=_temperature_sampling,
                  temperature=temperature,
                  top_k=0,
                  top_p=top_p,
                  **kwargs)


def argmax_sampling(*args, **kwargs):
  """Convenience wrapper for argmax sampling."""
  return generate(*args, generate_fn=_temperature_sampling,
                  temperature=1e-7,
                  top_k=0,
                  top_p=0.0,
                  **kwargs)


def generate(params, inputs, prompts, seed, *,
             model: nn.Module,
             generate_fn: GenerateFn,
             num_samples: int = 1,
             prefill: bool = False,
             eos_token: int = EOS_ID,
             **generate_fn_kwargs):
  """Generates sequences for a batch with fast autoregressive decoding.

  The model must support:
    encode(inputs) -> encoded, or encode(*inputs) -> encoded.
    decode(encoded, prompts, decode=True/False, max_decode_length) -> logits

  Args:
    params: model parameters.
    inputs: either a single `jnp.ndarray` (e.g. images), or a tuple of inputs
      which are passed via `model.encode(*inputs)`.
    prompts: [batch_size, max_decode_len] forced tokens for generation.
      Prompts must finish with trailing zero tokens and must not contain
      end-of-sequence markers. If no prompting is required, pass an all-zeros
      tensor.
    seed: PRNG key for random sampling.
    model: object with `encode` and `decode` methods.
    generate_fn: search or sampling function used to generate sequences.
    num_samples: number of samples to generate per batch item.
    prefill: whether to prefill the cache with the prompt tokens.
    eos_token: id of the end-of-sequence token in the target vocabulary.
    **generate_fn_kwargs: generate_fn-specific kwargs.

  Returns:
    Top-scoring sequences (worst scores first),
      [batch_size, num_samples, max_decode_len].
    Scores of the generated sequences (worst scores first). The returned
      scores are modified log probabilities; may be absent,
      [batch_size, num_samples].
    Log probs of the generated tokens; may be absent,
      [batch_size, num_samples, max_decode_len].
  """
  _, max_decode_len = prompts.shape
  decode_kwargs = {"max_decode_length": max_decode_len}

  def encode(model, inputs):
    if not isinstance(inputs, tuple):
      inputs = (inputs,)
    return model.encode(*inputs)

  encoded_inputs = nn.apply(encode, model)(params, inputs)
  if isinstance(encoded_inputs, tuple):
    encoded_inputs, enc_pos_emb = encoded_inputs
    decode_kwargs["enc_pos_emb"] = enc_pos_emb

  def init_cache(model):
    encoded = jnp.zeros_like(encoded_inputs)
    targets = jnp.zeros_like(prompts)
    return model.decode(encoded, targets, decode=True, **decode_kwargs)

  cache = nn.apply(init_cache, model, mutable=True)(params)[1]["cache"]

  def prefill_cache(model, encoded, targets):
    return model.decode(encoded, targets, prefill=True, **decode_kwargs)

  if prefill:
    cache = nn.apply(prefill_cache, model, mutable=True)(
        {"params": params["params"], "cache": cache},
        encoded_inputs, prompts)[1]["cache"]

  def tokens_to_logits(tokens, cache):
    def decode_step(model, tokens):
      encoded = expand_samples_dim_and_flatten(
          encoded_inputs, num_samples)
      return model.decode(encoded, tokens, decode=True, **decode_kwargs)

    logits, aux = nn.apply(decode_step, model, mutable=True)(
        {"params": params["params"], "cache": cache}, tokens)
    return logits.squeeze(axis=1), aux["cache"]

  beam_seqs, scores, logprobs = generate_fn(
      prompts,
      cache,
      tokens_to_logits,
      num_samples=num_samples,
      eos_token=eos_token,
      max_decode_len=max_decode_len,
      seed=seed,
      **generate_fn_kwargs)
  return beam_seqs, scores, logprobs
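

# Usage sketch (illustration only, not part of the original API): `model`,
# `params` and `images` below are placeholders for any Flax module and inputs
# that follow the encode/decode contract described in `generate`.
def _example_generate(params, model, images, max_decode_len=16):
  """Hedged example: draws 4 temperature samples per input image."""
  # All-zeros prompts mean no forced tokens.
  prompts = jnp.zeros((images.shape[0], max_decode_len), dtype=jnp.int32)
  seqs, scores, logprobs = temperature_sampling(
      params, images, prompts, jax.random.PRNGKey(0),
      model=model, num_samples=4, temperature=0.7)
  # seqs: [batch, 4, max_decode_len], scores: [batch, 4].
  return seqs, scores, logprobs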


def expand_samples_dim(x, num_samples):
  """Creates a new samples dimension in a non-scalar array and tiles into it."""
  if x.ndim == 0:  # Ignore scalars (e.g. cache index).
    return x
  x = jnp.expand_dims(x, axis=1)
  tile_dims = [1] * x.ndim
  tile_dims[1] = num_samples
  return jnp.tile(x, tile_dims)


def flatten_samples_dim(x):
  """Flattens the samples dim into the batch dim."""
  if x.ndim == 0:  # Ignore scalars (e.g. cache index).
    return x
  return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])


def unflatten_samples_dim(x, batch_size, num_samples):
  """Unflattens the first dim into batch and samples dims."""
  if x.ndim == 0:  # Ignore scalars (e.g. cache index).
    return x
  assert batch_size * num_samples == x.shape[0]
  return x.reshape((batch_size, num_samples) + x.shape[1:])


def expand_samples_dim_and_flatten(x, num_samples):
  """Expands each batch item by num_samples in the batch dimension."""
  return flatten_samples_dim(expand_samples_dim(x, num_samples))
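

# Shape walkthrough (illustration only) for the sample-dimension helpers above.
def _example_sample_dims():
  x = jnp.ones((2, 4))                                       # [B=2, D=4]
  y = expand_samples_dim(x, 3)                               # [2, 3, 4]
  z = flatten_samples_dim(y)                                 # [6, 4]
  w = unflatten_samples_dim(z, batch_size=2, num_samples=3)  # [2, 3, 4]
  assert expand_samples_dim_and_flatten(x, 3).shape == z.shape == (6, 4)
  return w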


def cache_map(fn, cache):
  """Maps a function over the cache, including caches of multiple layers."""
  frozen = isinstance(cache, flax.core.FrozenDict)
  if frozen:
    cache = flax.core.unfreeze(cache)
  flat_cache = flax.traverse_util.flatten_dict(cache)

  # Entries whose leaf key is "cached_bias" are left untouched.
  keyvals = {k: v for k, v in flat_cache.items() if k[-1] != "cached_bias"}
  keyvals = jax.tree_util.tree_map(fn, keyvals)
  flat_cache.update(keyvals)
  new_cache = flax.traverse_util.unflatten_dict(flat_cache)
  if frozen:
    new_cache = flax.core.freeze(new_cache)
  return new_cache
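

# Sketch (illustration only): expanding a toy cache so that every sample gets
# its own copy. The nested dict below merely mimics the layout of a Flax
# decoding cache; real caches come from `nn.apply(..., mutable=True)` as in
# `generate`.
def _example_cache_map():
  toy_cache = {"decoder_layer_0": {
      "cached_key": jnp.zeros((4, 8, 16)),  # [batch, length, features]
      "cache_index": jnp.array(0),          # scalar, passed through unchanged
  }}
  expanded = cache_map(lambda x: expand_samples_dim_and_flatten(x, 2),
                       toy_cache)
  assert expanded["decoder_layer_0"]["cached_key"].shape == (8, 8, 16)
  return expanded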


@flax.struct.dataclass
class LoopState:
  """Internal state of the temperature sampling loop."""

  # Current position in the sequences.
  cur_index: int
  # Model cache for fast autoregressive decoding.
  cache: Any
  # Per-sequence flags marking whether an eos token was already sampled.
  flags_finished: jnp.ndarray
  # Generated sequences, including the forced prompt tokens.
  sequences: jnp.ndarray
  # Cumulative log probabilities of the generated tokens.
  scores: jnp.ndarray
  # Per-token log probabilities of the generated tokens.
  logprobs: jnp.ndarray
  # PRNG key for sampling.
  rng: jnp.ndarray


def _init_state(prompts, cache, init_rng_key, num_samples):
  """Initializes the state of the sampling loop."""
  batch_size, max_decode_len_plus_one = prompts.shape

  # Expand the cache so that every sample gets its own copy.
  cache = cache_map(
      lambda x: expand_samples_dim_and_flatten(x, num_samples), cache)
  return LoopState(
      cur_index=0,
      cache=cache,
      flags_finished=jnp.zeros((batch_size*num_samples), dtype=jnp.bool_),
      sequences=expand_samples_dim_and_flatten(prompts, num_samples),
      scores=jnp.zeros((batch_size*num_samples)),
      logprobs=jnp.zeros((batch_size*num_samples, max_decode_len_plus_one-1)),
      rng=init_rng_key)
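

# Shape note (illustration only): with prompts already padded to [B, L+1] and
# num_samples=N, `_init_state` holds flattened per-sample state, e.g. for
# B=2, N=3, L=5: sequences [6, 6], flags_finished [6], scores [6],
# logprobs [6, 5]. Cache leaves are expanded the same way via `cache_map`.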


def _should_temperature_sampling_continue(state, max_decode_len):
  """Check if we should continue or not."""
  max_length_not_reached = state.cur_index < max_decode_len - 1
  all_seqs_finished = jnp.all(state.flags_finished)
  return max_length_not_reached & (~all_seqs_finished)


def _temperature_sampling_iteration(state, tokens_to_logits, temperature, eos,
                                    top_k, top_p, mask_token_ids=()):
  """Temperature sampling step function."""
  rng_sampling, rng = jax.random.split(state.rng)

  # Compute next-token logits for the current position and update the cache.
  cur_tokens = state.sequences[:, state.cur_index]
  logits, new_cache = tokens_to_logits(cur_tokens[:, None], state.cache)
  assert logits.ndim == 2, ("tokens_to_logits expected to return a "
                            f"2-dimensional array [B, V], got {logits.ndim} "
                            "dimensions.")
  logprobs = jax.nn.log_softmax(logits)

  # Optionally zero out the probability of masked token ids and renormalize.
  if mask_token_ids:
    probs = jax.nn.softmax(logits)
    for i in mask_token_ids:
      probs = probs.at[:, i].set(0.)
    probs = probs / jnp.sum(probs, -1, keepdims=True)
    logits = jnp.log(probs)

  # Nucleus (top-p) filtering: keep the smallest set of tokens whose
  # cumulative probability reaches top_p; mask the rest to NEG_INF.
  if top_p:
    logits_sorted = jnp.sort(logits, axis=-1)[:, ::-1]
    sorted_cum_probs = jnp.cumsum(
        jax.nn.softmax(logits_sorted, axis=-1), axis=-1)
    cutoff_index = jnp.sum(sorted_cum_probs < top_p, axis=-1, keepdims=True)
    cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1)
    logits = jnp.where(logits < cutoff_logit,
                       jnp.full_like(logits, NEG_INF), logits)

  # Sample from the (possibly filtered) distribution at the given temperature.
  if top_k:
    topk_logits, topk_indices = jax.lax.top_k(logits, top_k)
    topk_token = jax.random.categorical(rng_sampling, topk_logits / temperature)
    sampled_tokens = jnp.squeeze(
        jnp.take_along_axis(topk_indices, jnp.expand_dims(topk_token, -1),
                            axis=-1), axis=-1)
  else:
    sampled_tokens = jax.random.categorical(rng_sampling, logits / temperature)

  sampled_logprobs = jnp.squeeze(jnp.take_along_axis(
      logprobs, jnp.expand_dims(sampled_tokens, axis=1), axis=-1), axis=-1)

  # Only positions that are out of the prompt and belong to unfinished
  # sequences take the sampled token; forced prompt tokens (and padding of
  # finished sequences) are kept as-is.
  next_tokens = state.sequences[:, state.cur_index + 1]
  next_logprobs = jnp.squeeze(jnp.take_along_axis(
      logprobs, jnp.expand_dims(next_tokens, axis=1), axis=-1), axis=-1)
  out_of_prompt = next_tokens == 0
  update_pos = out_of_prompt * (~state.flags_finished)
  next_tokens = sampled_tokens * update_pos + next_tokens * (~update_pos)
  sampled_logprobs = update_pos*sampled_logprobs + ~update_pos*next_logprobs
  sequences = state.sequences.at[:, state.cur_index + 1].set(next_tokens)
  scores = state.scores + sampled_logprobs
  seqs_logprobs = state.logprobs.at[:, state.cur_index].set(sampled_logprobs)

  # A sequence is finished once it is out of the prompt and eos was sampled.
  flags_finished = out_of_prompt & (state.flags_finished |
                                    (sampled_tokens == eos))
  return LoopState(
      cur_index=state.cur_index+1,
      cache=new_cache,
      flags_finished=flags_finished,
      sequences=sequences,
      scores=scores,
      logprobs=seqs_logprobs,
      rng=rng)
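

# Worked example of the top-p cutoff above (illustration only): for one row
# with token probabilities [0.5, 0.3, 0.15, 0.05] and top_p=0.8, the sorted
# cumulative probabilities are [0.5, 0.8, 0.95, 1.0], so
# cutoff_index = sum(cum_probs < 0.8) = 1 and cutoff_logit is the logit of the
# 0.3 token. The 0.15 and 0.05 tokens are masked to NEG_INF, i.e. the smallest
# set of tokens whose cumulative probability reaches top_p survives. With
# top_k, sampling is instead restricted to the k largest logits.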


def _temperature_sampling(prompts, cache, tokens_to_logits, num_samples=1,
                          eos_token=EOS_ID, max_decode_len=None,
                          seed=0, temperature=1., top_k=0, top_p=0.0,
                          mask_token_ids=()):
  """Temperature sampling.

  Purely stochastic procedure to generate sequences: every next token is
  sampled from the discrete vocabulary distribution produced by the
  auto-regressive sequence model. Optionally, the distribution can be adjusted
  by changing the temperature before sampling from it. Generated sequences are
  no longer than max_decode_len.

  Args:
    prompts: prompt tokens [B, L] forced at the beginning of generation.
      Prompt sequences should finish with trailing zeros and should not
      contain eos tokens; pass an all-zeros tensor for free-form generation.
    cache: cache for fast decoding (generation).
    tokens_to_logits: fast autoregressive decoder function taking single token
      slices and the cache, and returning next-token logits and the updated
      cache.
    num_samples: number of samples to generate per batch item. Note that no
      deduplication is performed; depending on the parameter settings,
      identical sequences may be generated and returned.
    eos_token: end-of-sequence token.
    max_decode_len: maximal length of generated sequences (L).
    seed: int seed or PRNGKey for random sampling.
    temperature: positive real-valued sampling temperature. By default we
      sample from the original distribution. As the temperature approaches 0.,
      the entire distribution concentrates on the most probable outcome(s).
    top_k: limit sampling to the top-k logits. Zero means no limit.
    top_p: limit sampling to the smallest set of top tokens whose cumulative
      probability reaches top_p. Zero means no limit. Cannot be combined with
      top_k.
    mask_token_ids: if set, tokens with the given ids are never sampled.

  Returns:
    sequences: generated sequences [B, num_samples, L].
    scores: sums of log probabilities of the generated tokens
      [B, num_samples].
    logprobs: log probabilities of the generated tokens [B, num_samples, L].
  """
  if top_k > 0 and top_p > 0.0:
    raise ValueError(f"Cannot use both top_k {top_k} and top_p {top_p}.")
  if max_decode_len is None:
    max_decode_len = prompts.shape[1]

  # Prepend a zero token that conditions the very first sampling step.
  prompts = jnp.pad(prompts, ((0, 0), (1, 0)))
  eos = jnp.array(eos_token)
  if isinstance(seed, int):
    seed = jax.random.PRNGKey(seed)

  loop_init_state = _init_state(prompts, cache, seed, num_samples)
  should_temperature_sampling_continue_fn = functools.partial(
      _should_temperature_sampling_continue,
      max_decode_len=max_decode_len+1)
  temperature_sampling_iteration_fn = functools.partial(
      _temperature_sampling_iteration,
      tokens_to_logits=tokens_to_logits,
      temperature=temperature, top_k=top_k, top_p=top_p,
      eos=eos, mask_token_ids=mask_token_ids)

  final_state = lax.while_loop(
      should_temperature_sampling_continue_fn,
      temperature_sampling_iteration_fn,
      loop_init_state)

  # Strip the prepended zero token and separate the samples dimension.
  return (
      final_state.sequences[:, 1:].reshape((-1, num_samples, max_decode_len)),
      final_state.scores.reshape((-1, num_samples)),
      final_state.logprobs.reshape((-1, num_samples, max_decode_len)))
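

# Minimal end-to-end sketch (illustration only): drives `_temperature_sampling`
# with a dummy uniform decoder and an empty cache. Real callers go through
# `generate`/`temperature_sampling`; the names below are hypothetical.
def _example_temperature_sampling(vocab_size=8):
  def toy_tokens_to_logits(tokens, cache):
    # Uniform next-token distribution; the cache is passed through unchanged.
    return jnp.zeros((tokens.shape[0], vocab_size)), cache

  prompts = jnp.zeros((2, 5), dtype=jnp.int32)  # batch of 2, no forced tokens
  seqs, scores, logprobs = _temperature_sampling(
      prompts, {}, toy_tokens_to_logits, num_samples=3, seed=0)
  # seqs: [2, 3, 5], scores: [2, 3], logprobs: [2, 3, 5].
  return seqs, scores, logprobs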