youtube-music-transcribe / t5x /decoding_test.py
juancopi81's picture
Add t5x and mt3 models
b100e1c
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for t5x.decoding."""
import functools
from typing import Mapping, Tuple
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import api
from jax.experimental import host_callback as hcb
import jax.numpy as jnp
import numpy as np
from t5x import decoding
# Token id that the tests below pass as the end-of-sequence id to the decoding
# functions under test.
EOS_ID = 1
# Alias of `decoding.NEG_INF`, used below when building logits and edge
# potentials.
NEG_INF = decoding.NEG_INF
class DecodeTest(parameterized.TestCase):
  """Tests for the sampling, beam-search and cache helpers in `t5x.decoding`."""
def test_temperature_sample_uneven_prefix(self):
  """Samples with an explicit per-row `initial_index` and checks the output."""

  def token_to_logits(ids, cache):
    del ids, cache  # The logits below do not depend on the inputs.
    # Row 0 puts all of its mass on id 2, row 1 on id 3.
    row0 = [-1e7, -1e7, 0, -1e7]
    row1 = [-1e7, -1e7, -1e7, 0]
    return np.array([row0, row1], dtype=np.float32), {}

  prompts = np.array([[0, 5, 7, 1, 0, 0], [0, 6, 1, 0, 0, 0]])
  start_positions = np.array([3, 2])
  sampled_sequences, _ = decoding._temperature_sample_single_trial(
      prompts, {},
      token_to_logits,
      EOS_ID,
      jax.random.PRNGKey(0),
      topk=0,
      initial_index=start_positions)

  np.testing.assert_array_equal(
      np.array([[5, 7, 1, 2, 2, 2], [6, 1, 3, 3, 3, 3]]), sampled_sequences)
def test_temperature_sample_no_prefix(self):
  """Samples from all-zero inputs and expects each row's dominant token id."""
  batch, max_decode_len = 2, 3

  def token_to_logits(ids, cache):  # pylint: disable=unused-argument
    # Constant logits: row 0 always picks id 2 and row 1 always picks id 3.
    constant_logits = np.array(
        [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32)
    return constant_logits, {}

  empty_prompts = np.zeros((batch, max_decode_len), dtype=np.int32)
  sampled_sequences, _ = decoding._temperature_sample_single_trial(
      empty_prompts, {},
      token_to_logits,
      EOS_ID,
      jax.random.PRNGKey(0),
      topk=0)

  np.testing.assert_array_equal([[2, 2, 2], [3, 3, 3]], sampled_sequences)
def test_temperature_sample_prefix(self):
  """Checks that non-zero prompt tokens are kept and the rest is sampled."""

  def token_to_logits(ids, cache):  # pylint: disable=unused-argument
    # Constant logits: row 0 always picks id 2 and row 1 always picks id 3.
    constant_logits = np.array(
        [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32)
    return constant_logits, {}

  # Row 0 carries a length-3 prefix, row 1 a length-2 prefix.
  prompts = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32)
  sampled_sequences, _ = decoding._temperature_sample_single_trial(
      prompts, {},
      token_to_logits,
      EOS_ID,
      jax.random.PRNGKey(0),
      topk=0)

  np.testing.assert_array_equal([[5, 6, 7, 2, 2], [8, 9, 3, 3, 3]],
                                sampled_sequences)
def test_temperature_sample_with_zero_temperature(self):
  """Expects the arg-max token of each row when `temperature=0.0`."""
  batch, max_decode_len = 2, 3

  def token_to_logits(ids, cache):  # pylint: disable=unused-argument
    # Row 0 uses very large logits that are close to one another; its maximum
    # sits at id 2. Row 1 has its maximum at id 3.
    large_close_row = [1700.47, 1700.48, 1700.51, 1700.45]
    ordinary_row = [3.2, 4.8, -5.3, 5.6]
    return np.array([large_close_row, ordinary_row], dtype=np.float32), {}

  empty_prompts = np.zeros((batch, max_decode_len), dtype=np.int32)
  sampled_sequences, _ = decoding._temperature_sample_single_trial(
      empty_prompts, {},
      token_to_logits,
      EOS_ID,
      jax.random.PRNGKey(0),
      topk=4,
      temperature=0.0)

  np.testing.assert_array_equal([[2, 2, 2], [3, 3, 3]], sampled_sequences)
def test_temperature_sample_prefix_ending_with_eos(self):
  """Checks that decoding continues after a prompt whose last token is EOS."""

  def token_to_logits(ids, cache):  # pylint: disable=unused-argument
    # Constant logits: row 0 always picks id 2 and row 1 always picks id 3.
    constant_logits = np.array(
        [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32)
    return constant_logits, {}

  # Counting the initial dummy token and the trailing EOS, row 0 has a
  # length-4 prefix and row 1 a length-3 prefix.
  prompts = np.array([[0, 5, 6, 1, 0], [0, 8, 1, 0, 0]], dtype=np.int32)
  sampled_sequences, _ = decoding._temperature_sample_single_trial(
      prompts, {},
      token_to_logits,
      EOS_ID,
      jax.random.PRNGKey(0),
      topk=1)

  np.testing.assert_array_equal([[5, 6, 1, 2, 2], [8, 1, 3, 3, 3]],
                                sampled_sequences)
def test_temperature_sample_with_state_callback(self):
  """Checks that `state_callback_fn` can rewrite the sequences while sampling.

  The callback writes EOS right after the first sampled id 3, so each row is
  expected to contain a single 3 followed by EOS.
  """

  def token_to_logits(ids, cache):  # pylint: disable=unused-argument
    # A distribution with roughly all probability mass in sample id 3
    logits = np.array([[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]],
                      dtype=np.float32)
    return logits, {}

  def state_callback_fn(state):
    i, sequences, cache, cur_token, ended, rng, log_prob = state

    def callback_fn(current_index_and_sequences):
      """Add EOS token after first time token id 3 has been sampled."""
      current_index, sequences = current_index_and_sequences
      # Copy into a NumPy array so that it can be edited in place.
      sequences = np.array(sequences)
      for row, position in enumerate(current_index):
        if sequences[row, position] == 3:
          sequences[row, position + 1] = EOS_ID
      return sequences

    # Describe the callback result with the public `jax.ShapeDtypeStruct`
    # rather than reaching into the private `jax._src.api` module.
    sequences = hcb.call(
        callback_fn, (i, sequences),
        result_shape=jax.ShapeDtypeStruct(sequences.shape, sequences.dtype))
    return i, sequences, cache, cur_token, ended, rng, log_prob

  inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32)
  sampled_sequences, _ = decoding._temperature_sample_single_trial(
      inputs, {},
      token_to_logits,
      EOS_ID,
      jax.random.PRNGKey(0),
      topk=0,
      temperature=0.0,
      state_callback_fn=state_callback_fn)

  expected = [[5, 6, 7, 3, EOS_ID], [8, 9, 3, EOS_ID, 0]]
  np.testing.assert_array_equal(expected, sampled_sequences)
def test_temperature_sample_with_logit_callback(self):
  """Checks that `logit_callback_fn` can override the model's logits."""

  def token_to_logits(ids, cache):  # pylint: disable=unused-argument
    # The model itself is uniform over the targets: every logit is equal.
    uniform_row = [-1e7, -1e7, -1e7, -1e7]
    return np.array([uniform_row, uniform_row], dtype=np.float32), {}

  def logit_callback_fn(logits, state):
    del state  # unused
    # Boost id 2 for row 0 and id 3 for row 1 so that they are always chosen.
    for row, token_id in ((0, 2), (1, 3)):
      logits[row, token_id] = 0
    return logits

  # Row 0 carries a length-3 prefix, row 1 a length-2 prefix.
  prompts = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32)
  sampled_sequences, _ = decoding._temperature_sample_single_trial(
      prompts, {},
      token_to_logits,
      EOS_ID,
      jax.random.PRNGKey(0),
      topk=0,
      temperature=0.0,
      logit_callback_fn=logit_callback_fn)

  np.testing.assert_array_equal([[5, 6, 7, 2, 2], [8, 9, 3, 3, 3]],
                                sampled_sequences)
def test_temperature_sample_prefix_ending_with_eos_early_stop(self):
  """Checks that rows stop being filled once they sample EOS.

  `jax.random.categorical` is replaced by a stub that returns a scripted pair
  of tokens for every loop step.
  """
  batch, max_decode_len = 2, 7
  rng0 = jax.random.PRNGKey(0)

  # Scripted samples, one row per step `i` of the while loop inside
  # `decoding._temperature_sample_single_trial`.
  scripted = np.tile(np.array([2, 3]), (max_decode_len, 1))
  scripted[3] = [2, 1]  # Sequence 1 outputs EOS=1 when i = 3.
  scripted[4] = [1, 3]  # Sequence 0 outputs EOS=1 when i = 4.
  scripted = jnp.array(scripted)

  def mocked_categorical(rng_input, logits):  # pylint: disable=unused-argument
    """Ignores logit and returns only based on the rng_input."""
    step = 0
    carried_rng = rng0
    # Mimic the rng split done in `decoding.sample_loop_body_fn`.
    for j in range(max_decode_len):
      step_rng, carried_rng = jax.random.split(carried_rng)
      # Keep only the `j` whose key equals `rng_input`. Keys are a pair of
      # ints, so a full match sums to 2; hence the division by 2.
      matching_words = (step_rng == rng_input).sum()
      step += j * matching_words // 2
    # `step` now equals the while loop variable `i` of the caller.
    return scripted[step]

  def token_to_logits(ids, cache):  # pylint: disable=unused-argument
    # Placeholder values: random.categorical is mocked, so they are unused.
    return np.zeros((batch, 4), dtype=np.float32), {}

  prompts = np.array([[0, 5, 1, 0, 0, 0, 0], [0, 8, 0, 0, 0, 0, 0]],
                     dtype=np.int32)
  with mock.patch.object(jax.random, 'categorical', new=mocked_categorical):
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        prompts, {}, token_to_logits, EOS_ID, rng0, topk=0)

  np.testing.assert_array_equal(
      [[5, 1, 2, 2, 1, 0, 0], [8, 3, 3, 1, 0, 0, 0]], sampled_sequences)
def test_greedy_decoding_topk_sample_log_probs(self):
  """Checks sequences and log-probs for `topk=1`, with and without rescaling."""

  def token_to_logits(ids, cache):  # pylint: disable=unused-argument
    # Ids [2, 3] are sampled with probability [0.6, 0.4].
    return np.array([[-1e7, -1e7, -0.510825624, -0.916290732]],
                    dtype=np.float32), {}

  # Each case: (prompt, rescale_log_probs, expected sequence, expected score).
  cases = (
      ([[0, 2, 2, 2, 0]], True, [[2, 2, 2, 2, 2]], [0.0]),
      ([[0, 2, 2, 3, 0]], False, [[2, 2, 3, 2, 2]], [-1.02165125]),
  )
  for prompt, rescale, expected_sequence, expected_log_probs in cases:
    sampled_sequences, sampled_log_probs = (
        decoding._temperature_sample_single_trial(
            np.array(prompt, dtype=np.int32), {},
            token_to_logits,
            EOS_ID,
            jax.random.PRNGKey(0),
            topk=1,
            rescale_log_probs=rescale))
    np.testing.assert_array_equal(expected_sequence, sampled_sequences)
    np.testing.assert_array_almost_equal(expected_log_probs,
                                         sampled_log_probs)
def test_temperature_sample_log_prob(self):
  """Checks the log-probability returned alongside the sampled sequences.

  `jax.random.categorical` is replaced by a stub that returns scripted tokens,
  and the expected score of each row is the sum of the log-softmax values of
  the tokens it samples. The logits come from a fixed-seed generator so that
  the test inputs are the same on every run.
  """
  batch, max_decode_len = 2, 7
  rng0 = jax.random.PRNGKey(0)
  ret = [np.array([2, 3]) for _ in range(max_decode_len)]
  # Sequence 1 outputs EOS=1 when i = 3 where `i` is the while loop counter of
  # `decoding._temperature_sample_single_trial`.
  ret[3] = np.array([2, 1])
  # Sequence 0 outputs EOS=1 when i = 4.
  ret[4] = np.array([1, 3])
  ret = jax.numpy.array(ret)

  # TODO(hwchung): refactor this.
  def mocked_categorical(rng_input, logits):  # pylint: disable=unused-argument
    """Ignores logit and returns only based on the rng_input."""
    rng = rng0
    k = 0
    # Mimic the rng split done in `decoding.sample_loop_body_fn`.
    for j in range(max_decode_len):
      rng1, rng = jax.random.split(rng)
      # We want to sift out `j` for which rng1 == rng_input
      # rngs are a pair of ints. So sum the bool and divide by 2.
      k += j * (rng1 == rng_input).sum() // 2
    # `k` at this point is equal to the while loop variable `i` of the caller.
    return ret[k]

  # Seeded local generator instead of the global, unseeded `np.random` state.
  logits = np.random.RandomState(0).randn(batch, 4)
  token_to_logits = lambda ids, cache: (logits, {})
  inputs = np.array([[0, 5, 1, 0, 0, 0, 0], [0, 8, 0, 0, 0, 0, 0]],
                    dtype=np.int32)
  with mock.patch.object(jax.random, 'categorical', new=mocked_categorical):
    sampled_sequences, log_prob = decoding._temperature_sample_single_trial(
        inputs, {}, token_to_logits, EOS_ID, rng0, topk=0)

  log_probs = jax.nn.log_softmax(logits)
  expected = [[5, 1, 2, 2, 1, 0, 0], [8, 3, 3, 1, 0, 0, 0]]
  expected_log_prob = [
      log_probs[0, 2] + log_probs[0, 2] + log_probs[0, 1],
      log_probs[1, 3] + log_probs[1, 3] + log_probs[1, 1]
  ]
  expected_log_prob = np.array(expected_log_prob)
  np.testing.assert_array_equal(expected, sampled_sequences)
  np.testing.assert_allclose(expected_log_prob, log_prob, atol=1e-5)
def test_temperature_sample_num_decodes(self):
  """Checks input expansion and result ordering when `num_decodes=3`.

  The single-trial sampler is mocked, so only the wrapping logic of
  `decoding.temperature_sample` is exercised.
  """
  num_decodes = 3
  rng0 = jax.random.PRNGKey(0)
  inputs = np.array([[0, 5, 1, 0], [0, 8, 7, 0]], dtype=np.int32)

  # What the mocked single-trial sampler returns:
  #   flat_decodes: [batch * num_decodes, max_decode_len]
  #   flat_log_prob: [batch * num_decodes]
  flat_decodes = np.array([[5, 1, 4, 4], [5, 1, 5, 5], [5, 1, 3, 3],
                           [8, 7, 5, 5], [8, 7, 3, 3], [8, 7, 4, 4]])
  flat_log_prob = np.array([-2.3, -1.3, -3.6, -0.5, -2.5, -1.9])
  # Every input row is expected to be repeated `num_decodes` times.
  repeated_inputs = jnp.array([[0, 5, 1, 0], [0, 5, 1, 0], [0, 5, 1, 0],
                               [0, 8, 7, 0], [0, 8, 7, 0], [0, 8, 7, 0]])

  with mock.patch.object(decoding,
                         '_temperature_sample_single_trial') as single_trial:
    single_trial.return_value = flat_decodes, flat_log_prob
    decodes, scores = decoding.temperature_sample(
        inputs, {}, mock.Mock(), EOS_ID, rng0, num_decodes=num_decodes)
    # The actual decode function must be called with the expanded values.
    np.testing.assert_array_equal(single_trial.call_args[0][0],
                                  repeated_inputs)

  expected_decodes = [[[5, 1, 3, 3], [5, 1, 4, 4], [5, 1, 5, 5]],
                      [[8, 7, 3, 3], [8, 7, 4, 4], [8, 7, 5, 5]]]
  expected_scores = [[-3.6, -2.3, -1.3], [-2.5, -1.9, -0.5]]
  np.testing.assert_array_equal(decodes, expected_decodes)
  np.testing.assert_allclose(scores, expected_scores)
def test_temperature_sample_num_decodes_with_initial_index(self):
  """Checks that `initial_index` is expanded together with the inputs.

  Both the single-trial sampler and `decoding.cache_map` are mocked.
  """
  num_decodes = 3
  rng0 = jax.random.PRNGKey(0)
  inputs = np.array([[0, 5, 1, 0], [0, 8, 7, 0]], dtype=np.int32)
  initial_index = np.array([1, 2], dtype=np.int32)

  # What the mocked single-trial sampler returns:
  #   flat_decodes: [batch * num_decodes, max_decode_len]
  #   flat_log_prob: [batch * num_decodes]
  flat_decodes = np.array([[5, 1, 4, 4], [5, 1, 5, 5], [5, 1, 3, 3],
                           [8, 7, 5, 5], [8, 7, 3, 3], [8, 7, 4, 4]])
  flat_log_prob = np.array([-2.3, -1.3, -3.6, -0.5, -2.5, -1.9])
  # Inputs and start indices are expected to be repeated `num_decodes` times.
  repeated_inputs = jnp.array([[0, 5, 1, 0], [0, 5, 1, 0], [0, 5, 1, 0],
                               [0, 8, 7, 0], [0, 8, 7, 0], [0, 8, 7, 0]])
  repeated_initial_index = np.array([1, 1, 1, 2, 2, 2], dtype=np.int32)

  with mock.patch.object(decoding,
                         '_temperature_sample_single_trial') as single_trial:
    with mock.patch.object(decoding, 'cache_map') as cache_map_mock:
      single_trial.return_value = flat_decodes, flat_log_prob
      decodes, scores = decoding.temperature_sample(
          inputs, {},
          mock.Mock(),
          EOS_ID,
          rng0,
          num_decodes=num_decodes,
          initial_index=initial_index)

      # The actual decode function must be called with the expanded values.
      positional, keyword = single_trial.call_args
      np.testing.assert_array_equal(positional[0], repeated_inputs)
      np.testing.assert_array_equal(keyword['initial_index'],
                                    repeated_initial_index)
      # The function must also have been applied to the index in the cache
      # map.
      self.assertTrue(cache_map_mock.call_args[1]['apply_to_index'])

  expected_decodes = [[[5, 1, 3, 3], [5, 1, 4, 4], [5, 1, 5, 5]],
                      [[8, 7, 3, 3], [8, 7, 4, 4], [8, 7, 5, 5]]]
  expected_scores = [[-3.6, -2.3, -1.3], [-2.5, -1.9, -0.5]]
  np.testing.assert_array_equal(decodes, expected_decodes)
  np.testing.assert_allclose(scores, expected_scores)
@parameterized.named_parameters(
    {
        'testcase_name': 'no_initial_index',
        'initial_index': None,
        'expected_calls': 6,
    },
    {
        'testcase_name': 'initial_index',
        'initial_index': np.array([1, 2], dtype=np.int32),
        'expected_calls': 4,
    },
    {
        'testcase_name': 'lower_initial_index',
        'initial_index': np.array([1, 1], dtype=np.int32),
        # we decode 4 tokens out of the prompt
        'expected_calls': 5,
    },
)
def test_temperature_sample_max_decode_steps_with_initial_index(
    self, initial_index, expected_calls):
  """Checks call counts and outputs for `max_decode_steps=4`.

  Args:
    initial_index: value forwarded as `initial_index`, or None.
    expected_calls: expected number of `token_to_logits` invocations.
  """
  max_decode_steps = 4
  rng0 = jax.random.PRNGKey(0)
  prompts = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 0, 0, 0, 0, 0]],
                     dtype=np.int32)

  # Both rows always prefer id 3.
  always_three = np.array([[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]],
                          dtype=np.float32)
  token_to_logits = mock.Mock()
  token_to_logits.return_value = (always_three, {})

  # to unroll while loop
  with jax.disable_jit():
    decodes, scores = decoding.temperature_sample(
        prompts, {},
        token_to_logits,
        EOS_ID,
        rng0,
        initial_index=initial_index,
        topk=4,
        max_decode_steps=max_decode_steps)

  self.assertLen(token_to_logits.call_args_list, expected_calls)
  expected_output = jnp.expand_dims(
      np.array([[2, 3, 3, 3, 3, 0, 0, 0], [2, 2, 3, 3, 3, 3, 0, 0]]), 1)
  np.testing.assert_array_equal(decodes, expected_output)
  np.testing.assert_array_equal(scores, [[0.], [0.]])
def test_temperature_sample_max_decode_steps_endpad(self):
  """Checks `max_decode_steps` with rows that start at different indices."""
  max_decode_steps = 4
  rng0 = jax.random.PRNGKey(0)
  prompts = np.array(
      [[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 2, 2, 2, 2, 0],
       [0, 2, 2, 2, 0, 0, 0, 0]],
      dtype=np.int32)
  initial_index = np.array([1, 6, 0])

  # All three rows always prefer id 3.
  prefer_three = [-1e7, -1e7, -1e7, 0]
  token_to_logits = mock.Mock()
  token_to_logits.return_value = (np.array(
      [prefer_three, prefer_three, prefer_three], dtype=np.float32), {})

  # to unroll while loop
  with jax.disable_jit():
    decodes, scores = decoding.temperature_sample(
        prompts, {},
        token_to_logits,
        EOS_ID,
        rng0,
        initial_index=initial_index,
        topk=4,
        max_decode_steps=max_decode_steps)

  # `inputs[2]` starts from index 0. So it requires 3 calls to
  # `token_to_logits` to exit the prompt (these generated tokens are
  # overridden) and 4 more calls to fill the rest. `inputs[0]` only need 4
  # calls. In the last 3 calls, it generates but MUST NOT populate the
  # sequences because it is already ended.
  self.assertLen(token_to_logits.call_args_list, 7)
  expected_output = jnp.expand_dims(
      np.array(
          [[2, 3, 3, 3, 3, 0, 0, 0], [2, 2, 2, 2, 2, 2, 3, 3],
           [2, 2, 2, 3, 3, 3, 3, 0]],
          dtype=np.int32), 1)
  np.testing.assert_array_equal(decodes, expected_output)
  np.testing.assert_allclose(scores, [[0.], [0.], [0.]])
def test_temperature_sample_max_decode_steps_docstring_ex4(self):
  """Checks `max_decode_steps=2` with explicit start indices for each row."""
  max_decode_steps = 2
  rng0 = jax.random.PRNGKey(0)
  prompts = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 3, 4, 0, 0, 0, 0, 0]],
                     dtype=np.int32)
  initial_index = np.array([1, 2])

  # Row 0 always prefers id 2 and row 1 always prefers id 3.
  fixed_logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]],
                          dtype=np.float32)
  token_to_logits = mock.Mock()
  token_to_logits.return_value = (fixed_logits, {})

  # to unroll while loop
  with jax.disable_jit():
    decodes, _ = decoding.temperature_sample(
        prompts, {},
        token_to_logits,
        EOS_ID,
        rng0,
        initial_index=initial_index,
        topk=4,
        max_decode_steps=max_decode_steps)

  self.assertLen(token_to_logits.call_args_list, 2)
  expected_output = jnp.expand_dims(
      np.array([[2, 2, 2, 0, 0, 0, 0, 0], [3, 4, 3, 3, 0, 0, 0, 0]],
               dtype=np.int32), 1)
  np.testing.assert_array_equal(decodes, expected_output)
def test_temperature_sample_max_decode_steps_hard_limit(self):
  """Expects 4 generated tokens when the hard limit (4) is below the step budget (10)."""
  max_decode_steps = 10
  max_decode_steps_hard_limit = 4
  rng0 = jax.random.PRNGKey(0)
  prompts = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 0, 0, 0, 0, 0]],
                     dtype=np.int32)

  # Both rows always prefer id 3.
  always_three = np.array([[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]],
                          dtype=np.float32)
  token_to_logits = mock.Mock()
  token_to_logits.return_value = (always_three, {})

  # to unroll while loop
  with jax.disable_jit():
    decodes, scores = decoding.temperature_sample(
        prompts, {},
        token_to_logits,
        EOS_ID,
        rng0,
        topk=4,
        max_decode_steps=max_decode_steps,
        max_decode_steps_hard_limit=max_decode_steps_hard_limit)

  expected_output = jnp.expand_dims(
      np.array([[2, 3, 3, 3, 3, 0, 0, 0], [2, 2, 3, 3, 3, 3, 0, 0]]), 1)
  np.testing.assert_array_equal(decodes, expected_output)
  np.testing.assert_array_equal(scores, [[0.], [0.]])
def test_temperature_sample_topp(self):
  """Checks that a small enough `topp` makes sampling pick only id 3."""
  rng0 = jax.random.PRNGKey(0)
  empty_prompts = np.zeros((1, 20), dtype=np.int32)

  # logits correspond to (0.3, 0, 0.1, 0.6)
  model_logits = np.array([[-1.2, -1e7, -2.3, -0.51]], dtype=np.float32)
  token_to_logits = mock.Mock()
  token_to_logits.return_value = (model_logits, {})

  # In both scenarios below every position is expected to be id 3.
  all_threes = jnp.expand_dims(np.array([[3] * 20]), 1)

  # anything under 0.6 will trigger deterministic decoding.
  decodes, scores = decoding.temperature_sample(
      empty_prompts, {}, token_to_logits, EOS_ID, rng0, topp=0.55, topk=0)
  np.testing.assert_array_equal(decodes, all_threes)
  np.testing.assert_array_equal(scores, [[0.]])

  # temperature is applied first, so the distribution becomes
  # (0.27, 0, 0.069, 0.65), so if topp is 0.63, it should become greedy.
  decodes, scores = decoding.temperature_sample(
      empty_prompts, {},
      token_to_logits,
      EOS_ID,
      rng0,
      temperature=0.8,
      topp=0.63,
      topk=0)
  np.testing.assert_array_equal(decodes, all_threes)
  np.testing.assert_array_equal(scores, [[0.]])
def test_dynamic_topp_max_decode_steps(self):
  """Passes temperature, topp and max_decode_steps as arguments of a jitted function."""
  rng0 = jax.random.PRNGKey(0)
  empty_prompts = np.zeros((1, 20), dtype=np.int32)

  # logits correspond to (0.3, 0, 0.1, 0.6)
  model_logits = np.array([[-1.2, -1e7, -2.3, -0.51]], dtype=np.float32)
  token_to_logits = mock.Mock()
  token_to_logits.return_value = (model_logits, {})

  @jax.jit
  def dynamic_decode_fn(inputs, temperature, topp, max_decode_steps):
    return decoding.temperature_sample(
        inputs, {},
        token_to_logits,
        EOS_ID,
        rng0,
        temperature=temperature,
        topp=topp,
        topk=0,
        max_decode_steps=max_decode_steps)

  decodes, scores = dynamic_decode_fn(empty_prompts, 0.8, 0.63, 10)

  # Ten decoded steps of id 3, then ten untouched zeros.
  expected_output = jnp.expand_dims(np.array([[3] * 10 + [0] * 10]), 1)
  np.testing.assert_array_equal(decodes, expected_output)
  np.testing.assert_array_equal(scores, [[0.]])
def test_topp_log_probs(self):
  """Inspects the logits handed to `jax.random.categorical` after top-p."""
  rng0 = jax.random.PRNGKey(0)
  single_step_prompt = np.zeros((1, 1), dtype=np.int32)

  # logits correspond to (0.3, 0, 0.1, 0.6)
  model_logits = np.array([[-1.2, NEG_INF, -2.3, -0.51]], dtype=np.float32)
  token_to_logits = mock.Mock()
  token_to_logits.return_value = (model_logits, {})

  with jax.disable_jit():
    # this lets us see logits after topp and topk are applied
    with mock.patch.object(jax.random, 'categorical') as categorical_mock:
      categorical_mock.return_value = jnp.array([0], dtype=jnp.int32)
      decodes, _ = decoding.temperature_sample(
          single_step_prompt, {},
          token_to_logits,
          EOS_ID,
          rng0,
          temperature=1.4,
          topp=0.7,
          topk=0)

  self.assertLen(token_to_logits.call_args_list, 1)
  np.testing.assert_array_equal(decodes, jnp.asarray([[[0]]]))

  # Second positional argument of the first `categorical` call: the logits.
  first_call_positional = categorical_mock.call_args_list[0][0]
  np.testing.assert_array_almost_equal(
      first_call_positional[1],
      jnp.asarray([[-0.85714293, NEG_INF, NEG_INF, -0.36428571]]))
def test_add_beam_dim(self):
  """Checks that `add_beam_dim` repeats each row along a new beam axis."""
  rows = [[0, 5, 1, 0], [0, 8, 6, 9]]
  with_beams = decoding.add_beam_dim(
      np.array(rows, dtype=np.int32), beam_size=3)

  self.assertEqual(with_beams.shape, (2, 3, 4))
  # Every row is expected three times in a row.
  expected = [[row, row, row] for row in rows]
  np.testing.assert_array_equal(expected, with_beams)
def test_flat_batch_beam_expand(self):
  """Checks that `flat_batch_beam_expand` repeats rows in a flat batch."""
  batch_rows = np.array([[0, 5, 1, 0], [0, 8, 6, 9]], dtype=np.int32)
  expanded = decoding.flat_batch_beam_expand(batch_rows, beam_size=2)

  # Each row appears twice, consecutively, with no extra beam axis.
  expected = [[0, 5, 1, 0], [0, 5, 1, 0], [0, 8, 6, 9], [0, 8, 6, 9]]
  np.testing.assert_array_equal(expected, expanded)
def test_top_k_two_stage(self):
  """Checks `decoding.top_k_two_stage` against `jax.lax.top_k`.

  The shuffled inputs come from a fixed-seed generator so that any failure is
  reproducible from run to run.
  """
  rng = np.random.RandomState(0)

  def _test_top_k(batch_size, k):
    # Pick sufficiently large seq_len.
    seq_len = 2047 * k * batch_size
    # A permutation of 0..seq_len-1 holds distinct values, so there are no
    # ties among the top-k candidates.
    seq = rng.permutation(seq_len)
    # Integer division: `seq_len` is a multiple of `batch_size` by
    # construction.
    x = jnp.reshape(seq, (batch_size, seq_len // batch_size)).astype(
        jnp.float32)
    np.testing.assert_almost_equal(
        decoding.top_k_two_stage(x, k), jax.lax.top_k(x, k), decimal=5)

  # Test small batch cases (batch={1,8}, k=16).
  _test_top_k(1, 16)
  _test_top_k(8, 16)
  # Test large batch cases (batch={9,32}, k=11).
  _test_top_k(9, 11)
  _test_top_k(32, 11)
def test_cache_map(self):
  """Checks that `cache_map` transforms keys/values but not `cache_index`.

  With the default arguments, the expected cache has `fn` applied to every
  `cached_key` and `cached_values` leaf, while `cache_index` leaves are
  unchanged.
  """
  cache = {
      'layers_0': {
          'cached_key': jnp.ones([3, 6]),
          'cached_values': jnp.ones([3, 6]),
          'cache_index': jnp.ones([
              3,
          ]),
      },
      'layers_1': {
          'self_attention': {
              'cached_key': jnp.ones([2, 7]),
              'cached_values': jnp.ones([5, 8]),
              'cache_index': jnp.array(1),
          },
          'encoder_decoder_attention': {
              'cached_key': jnp.ones([10, 12, 2]),
              'cached_values': jnp.ones([4, 7, 2]),
              'cache_index': jnp.ones([4, 5, 6]),
          }
      },
  }

  fn = functools.partial(jnp.add, 4)

  gold_cache = {
      'layers_0': {
          'cached_key': fn(jnp.ones([3, 6])),
          'cached_values': fn(jnp.ones([3, 6])),
          'cache_index': jnp.ones([
              3,
          ]),
      },
      'layers_1': {
          'self_attention': {
              'cached_key': fn(jnp.ones([2, 7])),
              'cached_values': fn(jnp.ones([5, 8])),
              'cache_index': jnp.array(1),
          },
          'encoder_decoder_attention': {
              'cached_key': fn(jnp.ones([10, 12, 2])),
              'cached_values': fn(jnp.ones([4, 7, 2])),
              'cache_index': jnp.ones([4, 5, 6]),
          }
      }
  }

  # `jax.tree_multimap` is deprecated and removed from current JAX;
  # `tree_map` takes several trees and compares them leaf by leaf.
  jax.tree_util.tree_map(np.testing.assert_array_equal,
                         decoding.cache_map(fn, cache), gold_cache)
def test_cache_map_with_index(self):
  """Checks `cache_map` when called with `apply_to_index=True`.

  The expected cache has `fn` applied to every `cached_key`, `cached_values`
  and `cache_index` leaf, while `cached_bias` and `position_embedder_index`
  are unchanged.
  """
  cache = {
      'layers_0': {
          'cached_key': jnp.ones([3, 6]),
          'cached_values': jnp.ones([3, 6]),
          'cache_index': jnp.ones([
              3,
          ]),
      },
      'layers_1': {
          'relpos_bias': {
              'cached_bias': jnp.ones([1, 5, 3]),
          },
          'self_attention': {
              'cached_key': jnp.ones([2, 7]),
              'cached_values': jnp.ones([5, 8]),
              'cache_index': jnp.array(1),
          },
          'encoder_decoder_attention': {
              'cached_key': jnp.ones([10, 12, 2]),
              'cached_values': jnp.ones([4, 7, 2]),
              'cache_index': jnp.ones([4, 5, 6]),
          }
      },
      'position_embedder': {
          'position_embedder_index': jnp.array(-1),
      },
  }

  fn = functools.partial(jnp.add, 8)

  gold_cache = {
      'layers_0': {
          'cached_key': fn(jnp.ones([3, 6])),
          'cached_values': fn(jnp.ones([3, 6])),
          'cache_index': fn(jnp.ones([
              3,
          ])),
      },
      'layers_1': {
          'relpos_bias': {
              'cached_bias': jnp.ones([1, 5, 3]),
          },
          'self_attention': {
              'cached_key': fn(jnp.ones([2, 7])),
              'cached_values': fn(jnp.ones([5, 8])),
              'cache_index': fn(jnp.array(1)),
          },
          'encoder_decoder_attention': {
              'cached_key': fn(jnp.ones([10, 12, 2])),
              'cached_values': fn(jnp.ones([4, 7, 2])),
              'cache_index': fn(jnp.ones([4, 5, 6])),
          }
      },
      'position_embedder': {
          'position_embedder_index': jnp.array(-1),
      },
  }

  # `jax.tree_multimap` is deprecated and removed from current JAX;
  # `tree_map` takes several trees and compares them leaf by leaf.
  jax.tree_util.tree_map(np.testing.assert_array_equal,
                         decoding.cache_map(fn, cache, apply_to_index=True),
                         gold_cache)
def test_beam_search(self):
  """Runs beam search on a toy first-order Markov model.

  The model has four states, A, B, START and END (plus PAD), and the two
  best-scoring paths are expected to be returned as the two beams.
  """
  batch_size = 2
  beam_size = 2
  decode_length = 7
  # PAD doesn't matter for this test, but part of the contract for beam_search
  # is giving the PAD token id 0.
  states = ['PAD', 'A', 'B', 'START-', '-END']
  num_states = len(states)

  # Edge potentials (written inside edges for diagonals):
  #            1      -1     1      -1
  #         A ---- A ---- A ---- A ---- A
  #       0   \  -1  \  1   \  -1  \  1   0
  # START      X      X      X      X       END
  #       0   /  -1  /  1   /  -1  /  1   0
  #         B ---- B ---- B ---- B ---- B
  #            1      -1     1      -1

  # The A/B part of the picture above as a 3-tensor, one matrix per timestep.
  stay_rewarded = [[1, -1], [-1, 1]]
  switch_rewarded = [[-1, 1], [1, -1]]
  ab_edge_potentials = np.asarray(
      [stay_rewarded, switch_rewarded, stay_rewarded, switch_rewarded])

  # Embed it into the full state space: START, END, and PAD at 0.
  potentials = np.ones([6, 5, 5]) * NEG_INF
  potentials[1:5, 1:3, 1:3] = ab_edge_potentials
  for ab_state in (1, 2):
    # START can go to either A or B for free at t0
    potentials[0, 3, ab_state] = 0
    # either A or B can go to END for free at t5
    potentials[5, ab_state, 4] = 0
  # PAD can go to anything for free (doesn't matter for this test)
  potentials[:, 0, :] = 0

  # Leading dummy axis: broadcasts over the flattened batch x beam dim.
  edge_potentials = jnp.expand_dims(jnp.asarray(potentials), axis=0)
  # at time 0, we start with state=START=3
  logits0 = jnp.expand_dims(
      jnp.asarray([NEG_INF, NEG_INF, NEG_INF, 0, NEG_INF]), axis=0)

  def tokens_to_logits(
      token_indices: jnp.ndarray, state_cache: Mapping[str, jnp.ndarray]
  ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
    cur_iter = state_cache['cur_iter']
    flat_size = batch_size * beam_size

    # Select the potentials of the current timestep (clamped at 0 for t0).
    timestep = jnp.maximum(0, cur_iter[:, 0].astype(jnp.int32) - 1)
    cur_edge_potentials = jnp.squeeze(
        jnp.take_along_axis(
            edge_potentials,
            jnp.reshape(timestep, (flat_size, 1, 1, 1)),
            axis=1),
        axis=1)

    # get "logits" from edge potentials for requested tokens (except at t0)
    one_hot_tokens = jnp.reshape(
        jax.nn.one_hot(token_indices, num_states, axis=1),
        (flat_size, 1, num_states))
    cur_logits = jnp.squeeze(
        jnp.matmul(one_hot_tokens, cur_edge_potentials), axis=1)

    # use our START-only logits for t0, otherwise use the edge potentials
    logits_for_tokens = jnp.where(cur_iter == 0, logits0, cur_logits)

    # update state in the cache
    new_cache = state_cache.copy()
    new_cache['cur_iter'] = cur_iter + 1
    return logits_for_tokens, new_cache

  init_cache = {'cur_iter': jnp.zeros((batch_size, 1))}

  top_scoring, _ = decoding.beam_search(
      inputs=np.zeros([batch_size, decode_length]),
      cache=init_cache,
      tokens_to_logits=tokens_to_logits,
      eos_id=4,
      num_decodes=beam_size,
      alpha=0.0,
      max_decode_len=decode_length)

  # The two top scoring sequences should be a tie between
  # START-AABBA-END
  # and
  # START-BBAAB-END
  # (and greedy beam search will find both these with just two beams)
  top_scoring_strings = []
  for i in range(beam_size):
    top_scoring_strings.append(''.join(
        states[tok] for tok in top_scoring[0, i, :]))

  np.testing.assert_array_equal(['START-AABBA-END', 'START-BBAAB-END'],
                                top_scoring_strings)
def test_beam_search_force_decode_prefix(self):
  """Checks that beam search keeps prompt tokens and scores only the rest."""
  beam_size = 2
  # Per-row logits: row 0 favours id 2 over id 3, row 1 favours id 3 over 2.
  batch_logits = np.array([[-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4],
                           [-1e7, -1e10, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4]],
                          dtype=np.float32)

  def token_to_logits(ids, cache):  # pylint: disable=unused-argument
    # Give every beam of a row the same logits, then flatten batch x beam.
    per_beam = np.repeat(
        np.expand_dims(batch_logits, axis=1), [beam_size], axis=1)
    return decoding.flatten_beam_dim(per_beam), {}

  # batch element 0 has length 1 and element 1 has length 2.
  inputs = np.array([[0, 7, 0, 0, 0], [0, 4, 5, 0, 0]], dtype=np.int32)
  rolled_inputs = np.array([[7, 0, 0, 0, 0], [4, 5, 0, 0, 0]], dtype=np.int32)
  beam_search_sequences, decoding_scores = decoding.beam_search(
      inputs, {}, token_to_logits, EOS_ID, num_decodes=beam_size, alpha=0)

  # Prefixes are forced depending on inputs.
  # Beam search sequences and corresponding scores are in reverse order.
  self.assertTrue(np.all(np.diff(decoding_scores) >= 0))
  expected = np.array([[[7, 3, 2, 2, 2], [7, 2, 2, 2, 2]],
                       [[4, 5, 2, 3, 3], [4, 5, 3, 3, 3]]])
  np.testing.assert_array_equal(expected, beam_search_sequences)

  # Expected score of a beam: log-probs summed over non-prompt positions.
  expected_scores = []
  for beams, logits, prompt in zip(expected, batch_logits, rolled_inputs):
    # The logits are static, so one log-softmax covers the whole row.
    log_probs = jax.nn.log_softmax(logits)
    expected_scores.append([
        sum(0 if prompt_token != 0 else log_probs[token]
            for token, prompt_token in zip(beam, prompt))
        for beam in beams
    ])
  np.testing.assert_allclose(expected_scores, decoding_scores, atol=1e-5)
def test_beam_search_force_decode_no_prefix(self):
  """Checks beam search output when the inputs contain no prompt tokens."""
  beam_size = 2
  # Per-row logits: row 0 favours id 2 over id 3, row 1 favours id 3 over 2.
  batch_logits = np.array(
      [[-1e7, -1e10, -0.1, -0.9], [-1e7, -1e10, -0.9, -0.1]],
      dtype=np.float32)

  def token_to_logits(ids, cache):  # pylint: disable=unused-argument
    # Give every beam of a row the same logits, then flatten batch x beam.
    per_beam = np.repeat(
        np.expand_dims(batch_logits, axis=1), [beam_size], axis=1)
    return decoding.flatten_beam_dim(per_beam), {}

  # No prefix is passed.
  empty_prompts = np.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], dtype=np.int32)
  beam_search_sequences, decoding_scores = decoding.beam_search(
      empty_prompts, {}, token_to_logits, EOS_ID, num_decodes=beam_size)

  # Prefixes are forced depending on inputs.
  # Beam search sequences and corresponding scores are in reverse order.
  self.assertTrue(np.all(np.diff(decoding_scores) >= 0))
  np.testing.assert_array_equal(
      np.array([[[3, 2, 2, 2, 2], [2, 2, 2, 2, 2]],
                [[2, 3, 3, 3, 3], [3, 3, 3, 3, 3]]]), beam_search_sequences)
# Run the tests through absl's test runner when executed as a script.
if __name__ == '__main__':
  absltest.main()