# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for t5x.decoding.""" | |
import functools | |
from typing import Mapping, Tuple | |
from unittest import mock | |
from absl.testing import absltest | |
from absl.testing import parameterized | |
import jax | |
from jax._src import api | |
from jax.experimental import host_callback as hcb | |
import jax.numpy as jnp | |
import numpy as np | |
from t5x import decoding | |
EOS_ID = 1 | |
NEG_INF = decoding.NEG_INF | |
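
# For orientation, a minimal sketch (ours, not the t5x implementation) of the
# sampling rule that the temperature-sample tests below exercise: temperature
# rescales the logits, `topk` masks everything outside the k largest logits,
# and temperature 0.0 degenerates to greedy argmax. The helper name
# `_reference_sample_step` is illustrative only.
def _reference_sample_step(rng, logits, temperature=1.0, topk=0):
  """Samples one token id per row from [batch, vocab] logits."""
  if topk:
    # Keep only the top-k logits per row, masking the rest to ~ -inf.
    cutoff = jax.lax.top_k(logits, topk)[0][:, -1, None]
    logits = jnp.where(logits < cutoff, NEG_INF, logits)
  if temperature == 0.0:
    return jnp.argmax(logits, axis=-1)
  return jax.random.categorical(rng, logits / temperature)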


class DecodeTest(parameterized.TestCase):

  def test_temperature_sample_uneven_prefix(self):

    def token_to_logits(ids, cache):
      del ids
      del cache
      # Always sample id 2 for batch element 0 and id 3 for element 1.
      logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]],
                        dtype=np.float32)
      return logits, {}

    inputs = np.array([[0, 5, 7, 1, 0, 0], [0, 6, 1, 0, 0, 0]])
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {},
        token_to_logits,
        EOS_ID,
        jax.random.PRNGKey(0),
        topk=0,
        initial_index=np.array([3, 2]))
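    # Outputs are left-rolled relative to `inputs` (the leading dummy 0 is
    # dropped), and prefix tokens are forced: the forced EOS at the end of
    # element 0's prefix does not terminate decoding; sampling continues.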
    expected = np.array([[5, 7, 1, 2, 2, 2], [6, 1, 3, 3, 3, 3]])
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_no_prefix(self):
    batch, max_decode_len = 2, 3

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Always sample id 2 for batch element 0 and id 3 for element 1.
      logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]],
                        dtype=np.float32)
      return logits, {}

    inputs = np.zeros((batch, max_decode_len), dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {}, token_to_logits, EOS_ID, jax.random.PRNGKey(0), topk=0)
    expected = [[2, 2, 2], [3, 3, 3]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_prefix(self):

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Always sample id 2 for batch element 0 and id 3 for element 1.
      logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]],
                        dtype=np.float32)
      return logits, {}

    # batch element 0 has length 3 prefix and element 1 has length 2.
    inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {}, token_to_logits, EOS_ID, jax.random.PRNGKey(0), topk=0)
    expected = [[5, 6, 7, 2, 2], [8, 9, 3, 3, 3]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_with_zero_temperature(self):
    batch, max_decode_len = 2, 3

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Use very large logits that are close to one another.
      logits = np.array(
          [[1700.47, 1700.48, 1700.51, 1700.45], [3.2, 4.8, -5.3, 5.6]],
          dtype=np.float32)
      return logits, {}

    inputs = np.zeros((batch, max_decode_len), dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {},
        token_to_logits,
        EOS_ID,
        jax.random.PRNGKey(0),
        topk=4,
        temperature=0.0)
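    # At temperature 0.0 sampling reduces to argmax: the row maxima are
    # 1700.51 (id 2) and 5.6 (id 3), so decoding is deterministic even though
    # row 0's logits differ only in the second decimal place.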
    expected = [[2, 2, 2], [3, 3, 3]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_prefix_ending_with_eos(self):

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Always sample id 2 for batch element 0 and id 3 for element 1.
      logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]],
                        dtype=np.float32)
      return logits, {}

    # batch element 0 has a length-4 prefix (including the initial dummy token
    # and the trailing EOS) and element 1 has a length-3 prefix.
    inputs = np.array([[0, 5, 6, 1, 0], [0, 8, 1, 0, 0]], dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {}, token_to_logits, EOS_ID, jax.random.PRNGKey(0), topk=1)
    expected = [[5, 6, 1, 2, 2], [8, 1, 3, 3, 3]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_with_state_callback(self):

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # A distribution with roughly all probability mass on sample id 3.
      logits = np.array([[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]],
                        dtype=np.float32)
      return logits, {}

    def state_callback_fn(state):
      i, sequences, cache, cur_token, ended, rng, log_prob = state

      def callback_fn(current_index_and_sequences):
        """Adds an EOS token after the first time token id 3 is sampled."""
        current_index, sequences = current_index_and_sequences
        sequences = np.array(sequences)
        for i in range(len(current_index)):
          if sequences[i, current_index[i]] == 3:
            sequences[i, current_index[i] + 1] = EOS_ID
        return sequences

      # `hcb.call` ships (i, sequences) to the host, runs the Python callback
      # there, and returns an array with the declared shape/dtype.
      sequences = hcb.call(
          callback_fn, (i, sequences),
          result_shape=api.ShapeDtypeStruct(sequences.shape, sequences.dtype))
      return i, sequences, cache, cur_token, ended, rng, log_prob

    inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {},
        token_to_logits,
        EOS_ID,
        jax.random.PRNGKey(0),
        topk=0,
        temperature=0.0,
        state_callback_fn=state_callback_fn)
    expected = [[5, 6, 7, 3, EOS_ID], [8, 9, 3, EOS_ID, 0]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_with_logit_callback(self):

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # A uniform distribution over targets from the model.
      logits = np.array([[-1e7, -1e7, -1e7, -1e7], [-1e7, -1e7, -1e7, -1e7]],
                        dtype=np.float32)
      return logits, {}

    def logit_callback_fn(logits, state):
      del state  # unused
      # Rewrite logits to always sample id 2 for batch element 0 and
      # id 3 for element 1.
      logits[0, 2] = 0
      logits[1, 3] = 0
      return logits
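    # Note: the in-place assignment above works because `token_to_logits`
    # ignores its (possibly traced) arguments and returns a concrete
    # np.ndarray, so the callback mutates host memory rather than a tracer.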
    # batch element 0 has length 3 prefix and element 1 has length 2.
    inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {},
        token_to_logits,
        EOS_ID,
        jax.random.PRNGKey(0),
        topk=0,
        temperature=0.0,
        logit_callback_fn=logit_callback_fn)
    expected = [[5, 6, 7, 2, 2], [8, 9, 3, 3, 3]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_prefix_ending_with_eos_early_stop(self):
    batch, max_decode_len = 2, 7
    rng0 = jax.random.PRNGKey(0)

    ret = [np.array([2, 3]) for _ in range(max_decode_len)]
    # Sequence 1 outputs EOS=1 when i = 3, where `i` is the while loop counter
    # of `decoding._temperature_sample_single_trial`.
    ret[3] = np.array([2, 1])
    # Sequence 0 outputs EOS=1 when i = 4.
    ret[4] = np.array([1, 3])
    ret = jax.numpy.array(ret)

    def mocked_categorical(rng_input, logits):  # pylint: disable=unused-argument
      """Ignores the logits and returns a value based only on `rng_input`."""
      rng = rng0
      k = 0
      # Mimic the rng split done in `decoding.sample_loop_body_fn`.
      for j in range(max_decode_len):
        rng1, rng = jax.random.split(rng)
        # Sift out the `j` for which rng1 == rng_input. PRNG keys are a pair
        # of uint32s, so the elementwise comparison sums to 2 on a match;
        # divide by 2 to count each match once.
        k += j * (rng1 == rng_input).sum() // 2
      # `k` at this point equals the while loop variable `i` of the caller.
      return ret[k]

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # These values are not used in this test because `jax.random.categorical`
      # is directly mocked.
      dummy_logits = np.zeros((batch, 4), dtype=np.float32)
      return dummy_logits, {}

    inputs = np.array([[0, 5, 1, 0, 0, 0, 0], [0, 8, 0, 0, 0, 0, 0]],
                      dtype=np.int32)
    with mock.patch.object(jax.random, 'categorical', new=mocked_categorical):
      sampled_sequences, _ = decoding._temperature_sample_single_trial(
          inputs, {}, token_to_logits, EOS_ID, rng0, topk=0)

    expected = [[5, 1, 2, 2, 1, 0, 0], [8, 3, 3, 1, 0, 0, 0]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_greedy_decoding_topk_sample_log_probs(self):

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Sample [2, 3] with probability [0.6, 0.4]: the logits are
      # log(0.6) ~ -0.5108 and log(0.4) ~ -0.9163.
      logits = np.array([[-1e7, -1e7, -0.510825624, -0.916290732]],
                        dtype=np.float32)
      return logits, {}

    inputs = np.array([[0, 2, 2, 2, 0]], dtype=np.int32)
    sampled_sequences, sampled_log_probs = decoding._temperature_sample_single_trial(
        inputs, {},
        token_to_logits,
        EOS_ID,
        jax.random.PRNGKey(0),
        topk=1,
        rescale_log_probs=True)
    expected_sequence = [[2, 2, 2, 2, 2]]
    expected_log_probs = [0.0]
    np.testing.assert_array_equal(expected_sequence, sampled_sequences)
    np.testing.assert_array_almost_equal(expected_log_probs, sampled_log_probs)
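    # With rescale_log_probs=True, greedy (topk=1) decoding renormalizes the
    # single surviving candidate's probability to 1, hence a total log prob of
    # 0.0. Below, with rescale_log_probs=False, the two sampled (non-prompt)
    # tokens each contribute log(0.6), giving 2 * -0.5108... ~ -1.0217.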
    inputs = np.array([[0, 2, 2, 3, 0]], dtype=np.int32)
    sampled_sequences, sampled_log_probs = decoding._temperature_sample_single_trial(
        inputs, {},
        token_to_logits,
        EOS_ID,
        jax.random.PRNGKey(0),
        topk=1,
        rescale_log_probs=False)
    expected_sequence = [[2, 2, 3, 2, 2]]
    expected_log_probs = [-1.02165125]
    np.testing.assert_array_equal(expected_sequence, sampled_sequences)
    np.testing.assert_array_almost_equal(expected_log_probs, sampled_log_probs)

  def test_temperature_sample_log_prob(self):
    batch, max_decode_len = 2, 7
    rng0 = jax.random.PRNGKey(0)

    ret = [np.array([2, 3]) for _ in range(max_decode_len)]
    # Sequence 1 outputs EOS=1 when i = 3, where `i` is the while loop counter
    # of `decoding._temperature_sample_single_trial`.
    ret[3] = np.array([2, 1])
    # Sequence 0 outputs EOS=1 when i = 4.
    ret[4] = np.array([1, 3])
    ret = jax.numpy.array(ret)

    # TODO(hwchung): refactor this.
    def mocked_categorical(rng_input, logits):  # pylint: disable=unused-argument
      """Ignores the logits and returns a value based only on `rng_input`."""
      rng = rng0
      k = 0
      # Mimic the rng split done in `decoding.sample_loop_body_fn`.
      for j in range(max_decode_len):
        rng1, rng = jax.random.split(rng)
        # Sift out the `j` for which rng1 == rng_input. PRNG keys are a pair
        # of uint32s, so the elementwise comparison sums to 2 on a match;
        # divide by 2 to count each match once.
        k += j * (rng1 == rng_input).sum() // 2
      # `k` at this point equals the while loop variable `i` of the caller.
      return ret[k]

    logits = np.random.randn(batch, 4)
    token_to_logits = lambda ids, cache: (logits, {})
    inputs = np.array([[0, 5, 1, 0, 0, 0, 0], [0, 8, 0, 0, 0, 0, 0]],
                      dtype=np.int32)
    with mock.patch.object(jax.random, 'categorical', new=mocked_categorical):
      sampled_sequences, log_prob = decoding._temperature_sample_single_trial(
          inputs, {}, token_to_logits, EOS_ID, rng0, topk=0)

    log_probs = jax.nn.log_softmax(logits)
    expected = [[5, 1, 2, 2, 1, 0, 0], [8, 3, 3, 1, 0, 0, 0]]
    # Prompt tokens contribute 0; only the three sampled tokens per row count.
    expected_log_prob = [
        log_probs[0, 2] + log_probs[0, 2] + log_probs[0, 1],
        log_probs[1, 3] + log_probs[1, 3] + log_probs[1, 1]
    ]
    expected_log_prob = np.array(expected_log_prob)
    np.testing.assert_array_equal(expected, sampled_sequences)
    np.testing.assert_allclose(expected_log_prob, log_prob, atol=1e-5)

  def test_temperature_sample_num_decodes(self):
    num_decodes = 3
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 5, 1, 0], [0, 8, 7, 0]], dtype=np.int32)

    with mock.patch.object(decoding,
                           '_temperature_sample_single_trial') as mocked:
      # expanded_decodes: [batch * num_decodes, max_decode_len]
      expanded_decodes = np.array([[5, 1, 4, 4], [5, 1, 5, 5], [5, 1, 3, 3],
                                   [8, 7, 5, 5], [8, 7, 3, 3], [8, 7, 4, 4]])
      # expanded_log_prob: [batch * num_decodes]
      expanded_log_prob = np.array([-2.3, -1.3, -3.6, -0.5, -2.5, -1.9])
      mocked.return_value = expanded_decodes, expanded_log_prob

      decodes, scores = decoding.temperature_sample(
          inputs, {}, mock.Mock(), EOS_ID, rng0, num_decodes=num_decodes)

      expanded_inputs = jnp.array([[0, 5, 1, 0], [0, 5, 1, 0], [0, 5, 1, 0],
                                   [0, 8, 7, 0], [0, 8, 7, 0], [0, 8, 7, 0]])
      # Test that the actual decode function is called with the expanded
      # values.
      np.testing.assert_array_equal(mocked.call_args[0][0], expanded_inputs)
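    # Per batch item, the num_decodes samples come back sorted by ascending
    # score, i.e. the best decode is last.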
    np.testing.assert_array_equal(decodes,
                                  [[[5, 1, 3, 3], [5, 1, 4, 4], [5, 1, 5, 5]],
                                   [[8, 7, 3, 3], [8, 7, 4, 4], [8, 7, 5, 5]]])
    np.testing.assert_allclose(scores, [[-3.6, -2.3, -1.3], [-2.5, -1.9, -0.5]])

  def test_temperature_sample_num_decodes_with_initial_index(self):
    num_decodes = 3
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 5, 1, 0], [0, 8, 7, 0]], dtype=np.int32)
    initial_index = np.array([1, 2], dtype=np.int32)

    with mock.patch.object(decoding,
                           '_temperature_sample_single_trial') as mocked:
      with mock.patch.object(decoding, 'cache_map') as mocked_cache_map:
        # expanded_decodes: [batch * num_decodes, max_decode_len]
        expanded_decodes = np.array([[5, 1, 4, 4], [5, 1, 5, 5], [5, 1, 3, 3],
                                     [8, 7, 5, 5], [8, 7, 3, 3], [8, 7, 4, 4]])
        # expanded_log_prob: [batch * num_decodes]
        expanded_log_prob = np.array([-2.3, -1.3, -3.6, -0.5, -2.5, -1.9])
        mocked.return_value = expanded_decodes, expanded_log_prob

        decodes, scores = decoding.temperature_sample(
            inputs, {},
            mock.Mock(),
            EOS_ID,
            rng0,
            num_decodes=num_decodes,
            initial_index=initial_index)

        expanded_inputs = jnp.array([[0, 5, 1, 0], [0, 5, 1, 0], [0, 5, 1, 0],
                                     [0, 8, 7, 0], [0, 8, 7, 0], [0, 8, 7, 0]])
        expanded_initial_index = np.array([1, 1, 1, 2, 2, 2], dtype=np.int32)
        # Test that the actual decode function is called with the expanded
        # values.
        np.testing.assert_array_equal(mocked.call_args[0][0], expanded_inputs)
        np.testing.assert_array_equal(mocked.call_args[1]['initial_index'],
                                      expanded_initial_index)
        # Test that the function was applied to the index in the cache map.
        self.assertTrue(mocked_cache_map.call_args[1]['apply_to_index'])

    np.testing.assert_array_equal(decodes,
                                  [[[5, 1, 3, 3], [5, 1, 4, 4], [5, 1, 5, 5]],
                                   [[8, 7, 3, 3], [8, 7, 4, 4], [8, 7, 5, 5]]])
    np.testing.assert_allclose(scores, [[-3.6, -2.3, -1.3], [-2.5, -1.9, -0.5]])

  # `expected_calls` is the number of `token_to_logits` invocations implied by
  # max_decode_steps=4 and the prompt lengths for each choice of
  # `initial_index`.
  @parameterized.named_parameters(
      dict(testcase_name='no_initial_index', initial_index=None,
           expected_calls=6),
      dict(testcase_name='initial_index',
           initial_index=np.array([1, 2], dtype=np.int32), expected_calls=4),
      dict(testcase_name='lower_initial_index',
           initial_index=np.array([1, 1], dtype=np.int32),
           expected_calls=5),  # Decodes 4 steps after exiting the prompt.
  )
  def test_temperature_sample_max_decode_steps_with_initial_index(
      self, initial_index, expected_calls):
    max_decode_steps = 4
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 0, 0, 0, 0, 0]],
                      dtype=np.int32)
    token_to_logits = mock.Mock()
    token_to_logits.return_value = (np.array(
        [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]], dtype=np.float32), {})

    # to unroll the while loop
    with jax.disable_jit():
      decodes, scores = decoding.temperature_sample(
          inputs, {},
          token_to_logits,
          EOS_ID,
          rng0,
          initial_index=initial_index,
          topk=4,
          max_decode_steps=max_decode_steps)

    self.assertLen(token_to_logits.call_args_list, expected_calls)
    expected_output = np.array([[2, 3, 3, 3, 3, 0, 0, 0],
                                [2, 2, 3, 3, 3, 3, 0, 0]])
    expected_output = jnp.expand_dims(expected_output, 1)
    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_array_equal(scores, [[0.], [0.]])

  def test_temperature_sample_max_decode_steps_endpad(self):
    max_decode_steps = 4
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 2, 2, 2, 2, 0],
                       [0, 2, 2, 2, 0, 0, 0, 0]],
                      dtype=np.int32)
    initial_index = np.array([1, 6, 0])
    token_to_logits = mock.Mock()
    token_to_logits.return_value = (np.array(
        [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]],
        dtype=np.float32), {})

    # to unroll the while loop
    with jax.disable_jit():
      decodes, scores = decoding.temperature_sample(
          inputs, {},
          token_to_logits,
          EOS_ID,
          rng0,
          initial_index=initial_index,
          topk=4,
          max_decode_steps=max_decode_steps)

    # `inputs[2]` starts from index 0, so it takes 3 calls to `token_to_logits`
    # to exit the prompt (these generated tokens are overridden by the prompt)
    # and 4 more calls to fill the rest. `inputs[0]` needs only 4 calls. During
    # the last 3 calls, it still generates tokens but MUST NOT write them into
    # the sequence, because it has already ended.
    self.assertLen(token_to_logits.call_args_list, 7)
    expected_output = np.array(
        [[2, 3, 3, 3, 3, 0, 0, 0], [2, 2, 2, 2, 2, 2, 3, 3],
         [2, 2, 2, 3, 3, 3, 3, 0]],
        dtype=np.int32)
    expected_output = jnp.expand_dims(expected_output, 1)
    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_allclose(scores, [[0.], [0.], [0.]])

  def test_temperature_sample_max_decode_steps_docstring_ex4(self):
    max_decode_steps = 2
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 3, 4, 0, 0, 0, 0, 0]],
                      dtype=np.int32)
    initial_index = np.array([1, 2])
    token_to_logits = mock.Mock()
    token_to_logits.return_value = (np.array(
        [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32), {})

    # to unroll the while loop
    with jax.disable_jit():
      decodes, _ = decoding.temperature_sample(
          inputs, {},
          token_to_logits,
          EOS_ID,
          rng0,
          initial_index=initial_index,
          topk=4,
          max_decode_steps=max_decode_steps)

    self.assertLen(token_to_logits.call_args_list, 2)
    expected_output = np.array(
        [[2, 2, 2, 0, 0, 0, 0, 0], [3, 4, 3, 3, 0, 0, 0, 0]], dtype=np.int32)
    expected_output = jnp.expand_dims(expected_output, 1)
    np.testing.assert_array_equal(decodes, expected_output)

  def test_temperature_sample_max_decode_steps_hard_limit(self):
    max_decode_steps = 10
    max_decode_steps_hard_limit = 4
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 0, 0, 0, 0, 0]],
                      dtype=np.int32)
    token_to_logits = mock.Mock()
    token_to_logits.return_value = (np.array(
        [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]], dtype=np.float32), {})

    # to unroll the while loop
    with jax.disable_jit():
      decodes, scores = decoding.temperature_sample(
          inputs, {},
          token_to_logits,
          EOS_ID,
          rng0,
          topk=4,
          max_decode_steps=max_decode_steps,
          max_decode_steps_hard_limit=max_decode_steps_hard_limit)

    expected_output = np.array([[2, 3, 3, 3, 3, 0, 0, 0],
                                [2, 2, 3, 3, 3, 3, 0, 0]])
    expected_output = jnp.expand_dims(expected_output, 1)
    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_array_equal(scores, [[0.], [0.]])

  def test_temperature_sample_topp(self):
    rng0 = jax.random.PRNGKey(0)
    inputs = np.zeros((1, 20), dtype=np.int32)
    token_to_logits = mock.Mock()
    # The logits correspond to the probabilities (0.3, 0, 0.1, 0.6).
    token_to_logits.return_value = (np.array([[-1.2, -1e7, -2.3, -0.51]],
                                             dtype=np.float32), {})

    decodes, scores = decoding.temperature_sample(
        inputs, {}, token_to_logits, EOS_ID, rng0, topp=0.55,
        topk=0)  # anything under 0.6 will trigger deterministic decoding.

    expected_output = np.array([[3] * 20])
    expected_output = jnp.expand_dims(expected_output, 1)

    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_array_equal(scores, [[0.]])

    # Temperature is applied first, so the distribution becomes
    # (0.27, 0, 0.069, 0.65); with topp of 0.63, decoding is again greedy.
    decodes, scores = decoding.temperature_sample(
        inputs, {},
        token_to_logits,
        EOS_ID,
        rng0,
        temperature=0.8,
        topp=0.63,
        topk=0)

    expected_output = np.array([[3] * 20])
    expected_output = jnp.expand_dims(expected_output, 1)

    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_array_equal(scores, [[0.]])

  def test_dynamic_topp_max_decode_steps(self):
    rng0 = jax.random.PRNGKey(0)
    inputs = np.zeros((1, 20), dtype=np.int32)
    token_to_logits = mock.Mock()
    # The logits correspond to the probabilities (0.3, 0, 0.1, 0.6).
    token_to_logits.return_value = (np.array([[-1.2, -1e7, -2.3, -0.51]],
                                             dtype=np.float32), {})

    def dynamic_decode_fn(inputs, temperature, topp, max_decode_steps):
      return decoding.temperature_sample(
          inputs, {},
          token_to_logits,
          EOS_ID,
          rng0,
          temperature=temperature,
          topp=topp,
          topk=0,
          max_decode_steps=max_decode_steps)

    # Jitting verifies that temperature, topp, and max_decode_steps can be
    # passed dynamically rather than as compile-time constants.
    dynamic_decode_fn_jit = jax.jit(dynamic_decode_fn)

    decodes, scores = dynamic_decode_fn_jit(inputs, 0.8, 0.63, 10)

    expected_output = np.array([[3] * 10 + [0] * 10])
    expected_output = jnp.expand_dims(expected_output, 1)

    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_array_equal(scores, [[0.]])

  def test_topp_log_probs(self):
    rng0 = jax.random.PRNGKey(0)
    inputs = np.zeros((1, 1), dtype=np.int32)
    token_to_logits = mock.Mock()
    # The logits correspond to the probabilities (0.3, 0, 0.1, 0.6).
    token_to_logits.return_value = (np.array([[-1.2, NEG_INF, -2.3, -0.51]],
                                             dtype=np.float32), {})

    with jax.disable_jit():
      # This lets us see the logits after topp and topk are applied.
      with mock.patch.object(jax.random, 'categorical') as mocked:
        mocked.return_value = jnp.array([0], dtype=jnp.int32)
        decodes, _ = decoding.temperature_sample(
            inputs, {},
            token_to_logits,
            EOS_ID,
            rng0,
            temperature=1.4,
            topp=0.7,
            topk=0)

    self.assertLen(token_to_logits.call_args_list, 1)
    np.testing.assert_array_equal(decodes, jnp.asarray([[[0]]]))
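    # Temperature is applied first: -1.2 / 1.4 ~ -0.857143 and
    # -0.51 / 1.4 ~ -0.364286. Ids 1 and 2 fall outside the topp=0.7 nucleus,
    # so their logits are masked to NEG_INF.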
    np.testing.assert_array_almost_equal(
        mocked.call_args_list[0][0][1],
        jnp.asarray([[-0.85714293, NEG_INF, NEG_INF, -0.36428571]]))

  def test_add_beam_dim(self):
    x = np.array([[0, 5, 1, 0], [0, 8, 6, 9]], dtype=np.int32)
    y = decoding.add_beam_dim(x, beam_size=3)
    self.assertEqual(y.shape, (2, 3, 4))
    np.testing.assert_array_equal([[[0, 5, 1, 0], [0, 5, 1, 0], [0, 5, 1, 0]],
                                   [[0, 8, 6, 9], [0, 8, 6, 9], [0, 8, 6, 9]]],
                                  y)

  def test_flat_batch_beam_expand(self):
    x = np.array([[0, 5, 1, 0], [0, 8, 6, 9]], dtype=np.int32)
    np.testing.assert_array_equal(
        [[0, 5, 1, 0], [0, 5, 1, 0], [0, 8, 6, 9], [0, 8, 6, 9]],
        decoding.flat_batch_beam_expand(x, beam_size=2))
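
  # As the expected values above show, flat_batch_beam_expand behaves like
  # add_beam_dim followed by collapsing the (batch, beam) axes into one,
  # keeping batch-major order.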

  def test_top_k_two_stage(self):

    def _test_top_k(batch_size, k):
      # Pick a sufficiently large seq_len.
      seq_len = 2047 * k * batch_size
      seq = np.arange(seq_len)
      np.random.shuffle(seq)
      x = jnp.reshape(seq, (batch_size, seq_len // batch_size)).astype(
          jnp.float32)
      np.testing.assert_almost_equal(
          decoding.top_k_two_stage(x, k), jax.lax.top_k(x, k), decimal=5)

    # Test small batch cases (batch={1,8}, k=16).
    _test_top_k(1, 16)
    _test_top_k(8, 16)
    # Test large batch cases (batch={9,32}, k=11).
    _test_top_k(9, 11)
    _test_top_k(32, 11)
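
  # The two tests below pin down cache_map's contract: `fn` is applied to the
  # cached key/value arrays of a decoder cache pytree, `cache_index` entries
  # are mapped only when apply_to_index=True, and unrelated entries (e.g.
  # relpos_bias, position_embedder state) are left untouched.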

  def test_cache_map(self):
    cache = {
        'layers_0': {
            'cached_key': jnp.ones([3, 6]),
            'cached_values': jnp.ones([3, 6]),
            'cache_index': jnp.ones([3]),
        },
        'layers_1': {
            'self_attention': {
                'cached_key': jnp.ones([2, 7]),
                'cached_values': jnp.ones([5, 8]),
                'cache_index': jnp.array(1),
            },
            'encoder_decoder_attention': {
                'cached_key': jnp.ones([10, 12, 2]),
                'cached_values': jnp.ones([4, 7, 2]),
                'cache_index': jnp.ones([4, 5, 6]),
            }
        },
    }
    fn = functools.partial(jnp.add, 4)

    gold_cache = {
        'layers_0': {
            'cached_key': fn(jnp.ones([3, 6])),
            'cached_values': fn(jnp.ones([3, 6])),
            'cache_index': jnp.ones([3]),
        },
        'layers_1': {
            'self_attention': {
                'cached_key': fn(jnp.ones([2, 7])),
                'cached_values': fn(jnp.ones([5, 8])),
                'cache_index': jnp.array(1),
            },
            'encoder_decoder_attention': {
                'cached_key': fn(jnp.ones([10, 12, 2])),
                'cached_values': fn(jnp.ones([4, 7, 2])),
                'cache_index': jnp.ones([4, 5, 6]),
            }
        }
    }

    jax.tree_multimap(np.testing.assert_array_equal,
                      decoding.cache_map(fn, cache), gold_cache)

  def test_cache_map_with_index(self):
    cache = {
        'layers_0': {
            'cached_key': jnp.ones([3, 6]),
            'cached_values': jnp.ones([3, 6]),
            'cache_index': jnp.ones([3]),
        },
        'layers_1': {
            'relpos_bias': {
                'cached_bias': jnp.ones([1, 5, 3]),
            },
            'self_attention': {
                'cached_key': jnp.ones([2, 7]),
                'cached_values': jnp.ones([5, 8]),
                'cache_index': jnp.array(1),
            },
            'encoder_decoder_attention': {
                'cached_key': jnp.ones([10, 12, 2]),
                'cached_values': jnp.ones([4, 7, 2]),
                'cache_index': jnp.ones([4, 5, 6]),
            }
        },
        'position_embedder': {
            'position_embedder_index': jnp.array(-1),
        },
    }
    fn = functools.partial(jnp.add, 8)

    gold_cache = {
        'layers_0': {
            'cached_key': fn(jnp.ones([3, 6])),
            'cached_values': fn(jnp.ones([3, 6])),
            'cache_index': fn(jnp.ones([3])),
        },
        'layers_1': {
            'relpos_bias': {
                'cached_bias': jnp.ones([1, 5, 3]),
            },
            'self_attention': {
                'cached_key': fn(jnp.ones([2, 7])),
                'cached_values': fn(jnp.ones([5, 8])),
                'cache_index': fn(jnp.array(1)),
            },
            'encoder_decoder_attention': {
                'cached_key': fn(jnp.ones([10, 12, 2])),
                'cached_values': fn(jnp.ones([4, 7, 2])),
                'cache_index': fn(jnp.ones([4, 5, 6])),
            }
        },
        'position_embedder': {
            'position_embedder_index': jnp.array(-1),
        },
    }

    jax.tree_multimap(np.testing.assert_array_equal,
                      decoding.cache_map(fn, cache, apply_to_index=True),
                      gold_cache)

  def test_beam_search(self):
    # Toy problem: we have 4 states, A, B, START, END (plus PAD).
    # Scores are given by a first-order Markov model.
    batch_size = 2
    beam_size = 2
    # PAD doesn't matter for this test, but part of the contract for
    # beam_search is giving the PAD token id 0.
    states = ['PAD', 'A', 'B', 'START-', '-END']
    num_states = len(states)
    decode_length = 7

    # Edge potentials (written inside edges for diagonals):
    #            1      -1     1      -1
    #         A ---- A ---- A ---- A ---- A
    #       0   \ -1   \  1   \ -1   \  1   0
    # START      X      X      X      X      END
    #       0   / -1   /  1   / -1   /  1   0
    #         B ---- B ---- B ---- B ---- B
    #            1      -1     1      -1

    # Put the above edge potentials in a 3-tensor.
    ab_edge_potentials = np.asarray([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]],
                                     [[1, -1], [-1, 1]], [[-1, 1], [1, -1]]])
    # Now add on the START and END states, with PAD at 0.
    edge_potentials = np.ones([6, 5, 5]) * NEG_INF
    edge_potentials[1:5, 1:3, 1:3] = ab_edge_potentials
    # START can go to either A or B for free at t0.
    edge_potentials[0, 3, 1] = 0
    edge_potentials[0, 3, 2] = 0
    # Either A or B can go to END for free at t5.
    edge_potentials[5, 1, 4] = 0
    edge_potentials[5, 2, 4] = 0
    # PAD can go to anything for free (doesn't matter for this test).
    edge_potentials[:, 0, :] = 0
    edge_potentials = jnp.asarray(edge_potentials)

    # At time 0, we start with state=START=3.
    logits0 = jnp.asarray([NEG_INF, NEG_INF, NEG_INF, 0, NEG_INF])
    # Add a dummy flattened batch x beam dim for broadcasting.
    logits0 = jnp.expand_dims(logits0, axis=0)
    edge_potentials = jnp.expand_dims(edge_potentials, axis=0)

    def tokens_to_logits(
        token_indices: jnp.ndarray, state_cache: Mapping[str, jnp.ndarray]
    ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
      cur_iter = state_cache['cur_iter']
      # Grab the edge potentials for the current timestep.
      cur_edge_potentials = jnp.take_along_axis(
          edge_potentials,
          jnp.reshape(
              jnp.maximum(0, cur_iter[:, 0].astype(jnp.int32) - 1),
              (batch_size * beam_size, 1, 1, 1)),
          axis=1)
      cur_edge_potentials = jnp.squeeze(cur_edge_potentials, axis=1)
      # Get "logits" from the edge potentials for the requested tokens (except
      # at t0).
      cur_logits = jnp.matmul(
          jnp.reshape(
              jax.nn.one_hot(token_indices, num_states, axis=1),
              (batch_size * beam_size, 1, num_states)), cur_edge_potentials)
      cur_logits = jnp.squeeze(cur_logits, axis=1)
      # Use our START-only logits for t0, otherwise use the edge potentials.
      logits_for_tokens = jnp.where(cur_iter == 0, logits0, cur_logits)
      # Update the state in the cache.
      new_cache = state_cache.copy()
      new_cache['cur_iter'] = cur_iter + 1
      return logits_for_tokens, new_cache

    init_cache = {}
    init_cache['cur_iter'] = jnp.zeros((batch_size, 1))

    top_scoring, _ = decoding.beam_search(
        inputs=np.zeros([batch_size, decode_length]),
        cache=init_cache,
        tokens_to_logits=tokens_to_logits,
        eos_id=4,
        num_decodes=beam_size,
        alpha=0.0,
        max_decode_len=decode_length)

    # The two top-scoring sequences should be a tie between
    #   START-AABBA-END
    # and
    #   START-BBAAB-END
    # (and greedy beam search will find both of these with just two beams).
    top_scoring_strings = [
        ''.join(states[tok] for tok in top_scoring[0, i, :])
        for i in range(beam_size)
    ]
    expected = ['START-AABBA-END', 'START-BBAAB-END']
    np.testing.assert_array_equal(expected, top_scoring_strings)

  def test_beam_search_force_decode_prefix(self):
    beam_size = 2

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Use id 2 then 3 for batch element 0 and id 3 then 2 for element 1.
      logits = np.repeat(
          np.expand_dims(
              np.array([[-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4],
                        [-1e7, -1e10, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4]],
                       dtype=np.float32),
              axis=1), [beam_size],
          axis=1)
      logits = decoding.flatten_beam_dim(logits)
      return logits, {}

    # batch element 0 has a length-1 prefix and element 1 has a length-2 one.
    inputs = np.array([[0, 7, 0, 0, 0], [0, 4, 5, 0, 0]], dtype=np.int32)
    rolled_inputs = np.array([[7, 0, 0, 0, 0], [4, 5, 0, 0, 0]],
                             dtype=np.int32)
    beam_search_sequences, decoding_scores = decoding.beam_search(
        inputs, {}, token_to_logits, EOS_ID, num_decodes=beam_size, alpha=0)

    # Prefixes are forced depending on inputs. Beam search sequences and the
    # corresponding scores are in ascending order, i.e. the highest-scoring
    # beam comes last.
    self.assertTrue(np.all(np.diff(decoding_scores) >= 0))
    expected = np.array([[[7, 3, 2, 2, 2], [7, 2, 2, 2, 2]],
                         [[4, 5, 2, 3, 3], [4, 5, 3, 3, 3]]])
    np.testing.assert_array_equal(expected, beam_search_sequences)

    # Forced prompt tokens contribute 0 to the score; sampled tokens add their
    # log-probabilities directly, which are static since the logits never
    # change.
    expected_scores = []
    batch_logits = np.array([[-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4],
                             [-1e7, -1e10, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4]],
                            dtype=np.float32)
    for batch, logits, prompt in zip(expected, batch_logits, rolled_inputs):
      beam_expected_scores = []
      for beam in batch:
        log_probs = jax.nn.log_softmax(logits)
        beam_scores = []
        for token, prompt_token in zip(beam, prompt):
          if prompt_token != 0:
            beam_scores.append(0)
          else:
            beam_scores.append(log_probs[token])
        beam_expected_scores.append(sum(beam_scores))
      expected_scores.append(beam_expected_scores)
    np.testing.assert_allclose(expected_scores, decoding_scores, atol=1e-5)

  def test_beam_search_force_decode_no_prefix(self):
    beam_size = 2

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Use id 2 then 3 for batch element 0 and id 3 then 2 for element 1.
      logits = np.repeat(
          np.expand_dims(
              np.array([[-1e7, -1e10, -0.1, -0.9], [-1e7, -1e10, -0.9, -0.1]],
                       dtype=np.float32),
              axis=1), [beam_size],
          axis=1)
      logits = decoding.flatten_beam_dim(logits)
      return logits, {}

    # No prefix is passed.
    inputs = np.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], dtype=np.int32)
    beam_search_sequences, decoding_scores = decoding.beam_search(
        inputs, {}, token_to_logits, EOS_ID, num_decodes=beam_size)

    # With no prefix to force, the beams are free-running; sequences and
    # scores are still returned in ascending score order.
    self.assertTrue(np.all(np.diff(decoding_scores) >= 0))
    expected = np.array([[[3, 2, 2, 2, 2], [2, 2, 2, 2, 2]],
                         [[2, 3, 3, 3, 3], [3, 3, 3, 3, 3]]])
    np.testing.assert_array_equal(expected, beam_search_sequences)


if __name__ == '__main__':
  absltest.main()