# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for t5x.decoding."""

import functools
from typing import Mapping, Tuple
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax.experimental import host_callback as hcb
import jax.numpy as jnp
import numpy as np
from t5x import decoding

EOS_ID = 1
NEG_INF = decoding.NEG_INF


def _make_mocked_categorical(rng0, max_decode_len, ret):
  """Builds a stand-in for `jax.random.categorical` keyed on the rng input.

  The returned function ignores `logits`. It recovers which loop iteration it
  was called from by replaying `max_decode_len` splits of `rng0`, then returns
  `ret` at that iteration.

  Args:
    rng0: the PRNG key that is also passed to the sampler under test.
    max_decode_len: how many rng splits to replay.
    ret: array indexed by loop iteration; `ret[i]` holds the token ids to
      "sample" at iteration `i`.

  Returns:
    A function with the `(rng_input, logits)` signature used by the sampler.
  """

  def mocked_categorical(rng_input, logits):  # pylint: disable=unused-argument
    """Ignores logit and returns only based on the rng_input."""
    rng = rng0
    k = 0
    # Mimic the rng split done in `decoding.sample_loop_body_fn`.
    for j in range(max_decode_len):
      rng1, rng = jax.random.split(rng)
      # We want to sift out `j` for which rng1 == rng_input
      # rngs are a pair of ints. So sum the bool and divide by 2.
      k += j * (rng1 == rng_input).sum() // 2
    # `k` at this point is equal to the while loop variable `i` of the caller.
    return ret[k]

  return mocked_categorical


class DecodeTest(parameterized.TestCase):

  def test_temperature_sample_uneven_prefix(self):

    def token_to_logits(ids, cache):
      del ids
      del cache
      # Always sample id 2 for batch element 0 and id 3 for element 1.
      logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]],
                        dtype=np.float32)
      return logits, {}

    inputs = np.array([[0, 5, 7, 1, 0, 0], [0, 6, 1, 0, 0, 0]])
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {},
        token_to_logits,
        EOS_ID,
        jax.random.PRNGKey(0),
        topk=0,
        initial_index=np.array([3, 2]))
    expected = np.array([[5, 7, 1, 2, 2, 2], [6, 1, 3, 3, 3, 3]])
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_no_prefix(self):
    batch, max_decode_len = 2, 3

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Always sample id 2 for batch element 0 and id 3 for element 1.
      logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]],
                        dtype=np.float32)
      return logits, {}

    inputs = np.zeros((batch, max_decode_len), dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {}, token_to_logits, EOS_ID, jax.random.PRNGKey(0), topk=0)

    expected = [[2, 2, 2], [3, 3, 3]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_prefix(self):

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Always sample id 2 for batch element 0 and id 3 for element 1.
      logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]],
                        dtype=np.float32)
      return logits, {}

    # batch element 0 has length 3 prefix and element 1 has length 2.
    inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {}, token_to_logits, EOS_ID, jax.random.PRNGKey(0), topk=0)

    expected = [[5, 6, 7, 2, 2], [8, 9, 3, 3, 3]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_with_zero_temperature(self):
    batch, max_decode_len = 2, 3

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Use very large logits that are close to one another.
      logits = np.array(
          [[1700.47, 1700.48, 1700.51, 1700.45], [3.2, 4.8, -5.3, 5.6]],
          dtype=np.float32)
      return logits, {}

    inputs = np.zeros((batch, max_decode_len), dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {},
        token_to_logits,
        EOS_ID,
        jax.random.PRNGKey(0),
        topk=4,
        temperature=0.0)

    expected = [[2, 2, 2], [3, 3, 3]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_prefix_ending_with_eos(self):

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Always sample id 2 for batch element 0 and id 3 for element 1.
      logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]],
                        dtype=np.float32)
      return logits, {}

    # batch element 0 has length 4 prefix (including the initial dummy token and
    # the last eos) and element 1 has length 3.
    inputs = np.array([[0, 5, 6, 1, 0], [0, 8, 1, 0, 0]], dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {}, token_to_logits, EOS_ID, jax.random.PRNGKey(0), topk=1)

    expected = [[5, 6, 1, 2, 2], [8, 1, 3, 3, 3]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_with_state_callback(self):

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # A distribution with roughly all probability mass in sample id 3
      logits = np.array([[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]],
                        dtype=np.float32)
      return logits, {}

    def state_callback_fn(state):
      i, sequences, cache, cur_token, ended, rng, log_prob = state

      def callback_fn(current_index_and_sequences):
        """Add EOS token after first time token id 3 has been sampled."""
        current_index, sequences = current_index_and_sequences
        sequences = np.array(sequences)
        # `b` walks the batch; it is distinct from the loop counter `i` above.
        for b in range(len(current_index)):
          if sequences[b, current_index[b]] == 3:
            sequences[b, current_index[b] + 1] = EOS_ID
        return sequences

      sequences = hcb.call(
          callback_fn, (i, sequences),
          result_shape=jax.ShapeDtypeStruct(sequences.shape, sequences.dtype))
      return i, sequences, cache, cur_token, ended, rng, log_prob

    inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {},
        token_to_logits,
        EOS_ID,
        jax.random.PRNGKey(0),
        topk=0,
        temperature=0.0,
        state_callback_fn=state_callback_fn)

    expected = [[5, 6, 7, 3, EOS_ID], [8, 9, 3, EOS_ID, 0]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_with_logit_callback(self):

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # All logits are equal, i.e. a uniform distribution over targets from
      # the model; the callback below decides which id is sampled.
      logits = np.array([[-1e7, -1e7, -1e7, -1e7], [-1e7, -1e7, -1e7, -1e7]],
                        dtype=np.float32)
      return logits, {}

    def logit_callback_fn(logits, state):
      del state  # unused
      # Rewrite logits to always sample id 2 for batch element 0 and
      # id 3 for element 1.
      logits[0, 2] = 0
      logits[1, 3] = 0
      return logits

    # batch element 0 has length 3 prefix and element 1 has length 2.
    inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {},
        token_to_logits,
        EOS_ID,
        jax.random.PRNGKey(0),
        topk=0,
        temperature=0.0,
        logit_callback_fn=logit_callback_fn)

    expected = [[5, 6, 7, 2, 2], [8, 9, 3, 3, 3]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_prefix_ending_with_eos_early_stop(self):
    batch, max_decode_len = 2, 7
    rng0 = jax.random.PRNGKey(0)

    ret = [np.array([2, 3]) for _ in range(max_decode_len)]
    # Sequence 1 outputs EOS=1 when i = 3 where `i` is the while loop counter of
    # `decoding._temperature_sample_single_trial`.
    ret[3] = np.array([2, 1])
    # Sequence 0 outputs EOS=1 when i = 4.
    ret[4] = np.array([1, 3])
    ret = jnp.array(ret)

    mocked_categorical = _make_mocked_categorical(rng0, max_decode_len, ret)

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # These values are not used in this test because random.categorical is
      # directly mocked.
      dummy_logits = np.zeros((batch, 4), dtype=np.float32)
      return dummy_logits, {}

    inputs = np.array([[0, 5, 1, 0, 0, 0, 0], [0, 8, 0, 0, 0, 0, 0]],
                      dtype=np.int32)
    with mock.patch.object(jax.random, 'categorical', new=mocked_categorical):
      sampled_sequences, _ = decoding._temperature_sample_single_trial(
          inputs, {}, token_to_logits, EOS_ID, rng0, topk=0)

    expected = [[5, 1, 2, 2, 1, 0, 0], [8, 3, 3, 1, 0, 0, 0]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_greedy_decoding_topk_sample_log_probs(self):

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Sample [2, 3] with probability [0.6, 0.4].
      logits = np.array([[-1e7, -1e7, -0.510825624, -0.916290732]],
                        dtype=np.float32)
      return logits, {}

    inputs = np.array([[0, 2, 2, 2, 0]], dtype=np.int32)
    sampled_sequences, sampled_log_probs = decoding._temperature_sample_single_trial(
        inputs, {},
        token_to_logits,
        EOS_ID,
        jax.random.PRNGKey(0),
        topk=1,
        rescale_log_probs=True)

    expected_sequence = [[2, 2, 2, 2, 2]]
    expected_log_probs = [0.0]
    np.testing.assert_array_equal(expected_sequence, sampled_sequences)
    np.testing.assert_array_almost_equal(expected_log_probs, sampled_log_probs)

    inputs = np.array([[0, 2, 2, 3, 0]], dtype=np.int32)
    sampled_sequences, sampled_log_probs = decoding._temperature_sample_single_trial(
        inputs, {},
        token_to_logits,
        EOS_ID,
        jax.random.PRNGKey(0),
        topk=1,
        rescale_log_probs=False)

    expected_sequence = [[2, 2, 3, 2, 2]]
    expected_log_probs = [-1.02165125]
    np.testing.assert_array_equal(expected_sequence, sampled_sequences)
    np.testing.assert_array_almost_equal(expected_log_probs, sampled_log_probs)

  def test_temperature_sample_log_prob(self):
    batch, max_decode_len = 2, 7
    rng0 = jax.random.PRNGKey(0)

    ret = [np.array([2, 3]) for _ in range(max_decode_len)]
    # Sequence 1 outputs EOS=1 when i = 3 where `i` is the while loop counter of
    # `decoding._temperature_sample_single_trial`.
    ret[3] = np.array([2, 1])
    # Sequence 0 outputs EOS=1 when i = 4.
    ret[4] = np.array([1, 3])
    ret = jnp.array(ret)

    mocked_categorical = _make_mocked_categorical(rng0, max_decode_len, ret)

    logits = np.random.randn(batch, 4)
    token_to_logits = lambda ids, cache: (logits, {})
    inputs = np.array([[0, 5, 1, 0, 0, 0, 0], [0, 8, 0, 0, 0, 0, 0]],
                      dtype=np.int32)
    with mock.patch.object(jax.random, 'categorical', new=mocked_categorical):
      sampled_sequences, log_prob = decoding._temperature_sample_single_trial(
          inputs, {}, token_to_logits, EOS_ID, rng0, topk=0)

    log_probs = jax.nn.log_softmax(logits)
    expected = [[5, 1, 2, 2, 1, 0, 0], [8, 3, 3, 1, 0, 0, 0]]
    expected_log_prob = [
        log_probs[0, 2] + log_probs[0, 2] + log_probs[0, 1],
        log_probs[1, 3] + log_probs[1, 3] + log_probs[1, 1]
    ]
    expected_log_prob = np.array(expected_log_prob)
    np.testing.assert_array_equal(expected, sampled_sequences)
    np.testing.assert_allclose(expected_log_prob, log_prob, atol=1e-5)

  def test_temperature_sample_num_decodes(self):
    num_decodes = 3
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 5, 1, 0], [0, 8, 7, 0]], dtype=np.int32)

    with mock.patch.object(decoding,
                           '_temperature_sample_single_trial') as mocked:
      # expanded_decodes: [batch * num_decodes, max_decode_len]
      expanded_decodes = np.array([[5, 1, 4, 4], [5, 1, 5, 5], [5, 1, 3, 3],
                                   [8, 7, 5, 5], [8, 7, 3, 3], [8, 7, 4, 4]])
      # expanded_log_prob: [batch * num_decodes]
      expanded_log_prob = np.array([-2.3, -1.3, -3.6, -0.5, -2.5, -1.9])
      mocked.return_value = expanded_decodes, expanded_log_prob

      decodes, scores = decoding.temperature_sample(
          inputs, {}, mock.Mock(), EOS_ID, rng0, num_decodes=num_decodes)

      expanded_inputs = jnp.array([[0, 5, 1, 0], [0, 5, 1, 0], [0, 5, 1, 0],
                                   [0, 8, 7, 0], [0, 8, 7, 0], [0, 8, 7, 0]])
      # Test that the actual decode function is called with the expanded values.
      np.testing.assert_array_equal(mocked.call_args[0][0], expanded_inputs)

    np.testing.assert_array_equal(decodes,
                                  [[[5, 1, 3, 3], [5, 1, 4, 4], [5, 1, 5, 5]],
                                   [[8, 7, 3, 3], [8, 7, 4, 4], [8, 7, 5, 5]]])
    np.testing.assert_allclose(scores, [[-3.6, -2.3, -1.3], [-2.5, -1.9, -0.5]])

  def test_temperature_sample_num_decodes_with_initial_index(self):
    num_decodes = 3
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 5, 1, 0], [0, 8, 7, 0]], dtype=np.int32)
    initial_index = np.array([1, 2], dtype=np.int32)

    with mock.patch.object(decoding,
                           '_temperature_sample_single_trial') as mocked:
      with mock.patch.object(decoding, 'cache_map') as mocked_cache_map:
        # expanded_decodes: [batch * num_decodes, max_decode_len]
        expanded_decodes = np.array([[5, 1, 4, 4], [5, 1, 5, 5], [5, 1, 3, 3],
                                     [8, 7, 5, 5], [8, 7, 3, 3], [8, 7, 4, 4]])
        # expanded_log_prob: [batch * num_decodes]
        expanded_log_prob = np.array([-2.3, -1.3, -3.6, -0.5, -2.5, -1.9])
        mocked.return_value = expanded_decodes, expanded_log_prob

        decodes, scores = decoding.temperature_sample(
            inputs, {},
            mock.Mock(),
            EOS_ID,
            rng0,
            num_decodes=num_decodes,
            initial_index=initial_index)

        expanded_inputs = jnp.array([[0, 5, 1, 0], [0, 5, 1, 0], [0, 5, 1, 0],
                                     [0, 8, 7, 0], [0, 8, 7, 0], [0, 8, 7, 0]])
        expanded_initial_index = np.array([1, 1, 1, 2, 2, 2], dtype=np.int32)
        # Test that the actual decode function is called with the expanded
        # values.
        np.testing.assert_array_equal(mocked.call_args[0][0], expanded_inputs)
        np.testing.assert_array_equal(mocked.call_args[1]['initial_index'],
                                      expanded_initial_index)
        # Test that the function was applied to the index in the cache map
        self.assertTrue(mocked_cache_map.call_args[1]['apply_to_index'])

    np.testing.assert_array_equal(decodes,
                                  [[[5, 1, 3, 3], [5, 1, 4, 4], [5, 1, 5, 5]],
                                   [[8, 7, 3, 3], [8, 7, 4, 4], [8, 7, 5, 5]]])
    np.testing.assert_allclose(scores, [[-3.6, -2.3, -1.3], [-2.5, -1.9, -0.5]])

  @parameterized.named_parameters(
      dict(
          testcase_name='no_initial_index',
          initial_index=None,
          expected_calls=6,
      ),
      dict(
          testcase_name='initial_index',
          initial_index=np.array([1, 2], dtype=np.int32),
          expected_calls=4,
      ),
      dict(
          testcase_name='lower_initial_index',
          initial_index=np.array([1, 1], dtype=np.int32),
          expected_calls=5,  # we decode 4 tokens out of the prompt
      ),
  )
  def test_temperature_sample_max_decode_steps_with_initial_index(
      self, initial_index, expected_calls):
    max_decode_steps = 4
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 0, 0, 0, 0, 0]],
                      dtype=np.int32)

    token_to_logits = mock.Mock()
    token_to_logits.return_value = (np.array(
        [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]], dtype=np.float32), {})

    # to unroll while loop
    with jax.disable_jit():
      decodes, scores = decoding.temperature_sample(
          inputs, {},
          token_to_logits,
          EOS_ID,
          rng0,
          initial_index=initial_index,
          topk=4,
          max_decode_steps=max_decode_steps)

    self.assertLen(token_to_logits.call_args_list, expected_calls)

    expected_output = np.array([[2, 3, 3, 3, 3, 0, 0, 0],
                                [2, 2, 3, 3, 3, 3, 0, 0]])
    expected_output = jnp.expand_dims(expected_output, 1)

    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_array_equal(scores, [[0.], [0.]])

  def test_temperature_sample_max_decode_steps_endpad(self):
    max_decode_steps = 4
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 2, 2, 2, 2, 0],
                       [0, 2, 2, 2, 0, 0, 0, 0]],
                      dtype=np.int32)
    initial_index = np.array([1, 6, 0])

    token_to_logits = mock.Mock()
    token_to_logits.return_value = (np.array(
        [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]],
        dtype=np.float32), {})

    # to unroll while loop
    with jax.disable_jit():
      decodes, scores = decoding.temperature_sample(
          inputs, {},
          token_to_logits,
          EOS_ID,
          rng0,
          initial_index=initial_index,
          topk=4,
          max_decode_steps=max_decode_steps)

    # `inputs[2]` starts from index 0. So it requires 3 calls to
    # `token_to_logits` to exit the prompt (these generated tokens are
    # overridden) and 4 more calls to fill the rest. `inputs[0]` only needs 4
    # calls. In the last 3 calls, it generates but MUST NOT populate the
    # sequences because it is already ended.
    self.assertLen(token_to_logits.call_args_list, 7)
    expected_output = np.array(
        [[2, 3, 3, 3, 3, 0, 0, 0], [2, 2, 2, 2, 2, 2, 3, 3],
         [2, 2, 2, 3, 3, 3, 3, 0]],
        dtype=np.int32)
    expected_output = jnp.expand_dims(expected_output, 1)

    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_allclose(scores, [[0.], [0.], [0.]])

  def test_temperature_sample_max_decode_steps_docstring_ex4(self):
    max_decode_steps = 2
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 3, 4, 0, 0, 0, 0, 0]],
                      dtype=np.int32)
    initial_index = np.array([1, 2])

    token_to_logits = mock.Mock()
    token_to_logits.return_value = (np.array(
        [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32), {})

    # to unroll while loop
    with jax.disable_jit():
      decodes, _ = decoding.temperature_sample(
          inputs, {},
          token_to_logits,
          EOS_ID,
          rng0,
          initial_index=initial_index,
          topk=4,
          max_decode_steps=max_decode_steps)
    self.assertLen(token_to_logits.call_args_list, 2)
    expected_output = np.array(
        [[2, 2, 2, 0, 0, 0, 0, 0], [3, 4, 3, 3, 0, 0, 0, 0]], dtype=np.int32)
    expected_output = jnp.expand_dims(expected_output, 1)

    np.testing.assert_array_equal(decodes, expected_output)

  def test_temperature_sample_max_decode_steps_hard_limit(self):
    max_decode_steps = 10
    max_decode_steps_hard_limit = 4
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 0, 0, 0, 0, 0]],
                      dtype=np.int32)

    token_to_logits = mock.Mock()
    token_to_logits.return_value = (np.array(
        [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]], dtype=np.float32), {})

    # to unroll while loop
    with jax.disable_jit():
      decodes, scores = decoding.temperature_sample(
          inputs, {},
          token_to_logits,
          EOS_ID,
          rng0,
          topk=4,
          max_decode_steps=max_decode_steps,
          max_decode_steps_hard_limit=max_decode_steps_hard_limit)

    expected_output = np.array([[2, 3, 3, 3, 3, 0, 0, 0],
                                [2, 2, 3, 3, 3, 3, 0, 0]])
    expected_output = jnp.expand_dims(expected_output, 1)

    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_array_equal(scores, [[0.], [0.]])

  def test_temperature_sample_topp(self):
    rng0 = jax.random.PRNGKey(0)
    inputs = np.zeros((1, 20), dtype=np.int32)

    token_to_logits = mock.Mock()

    # logits correspond to (0.3, 0, 0.1, 0.6)
    token_to_logits.return_value = (np.array([[-1.2, -1e7, -2.3, -0.51]],
                                             dtype=np.float32), {})

    decodes, scores = decoding.temperature_sample(
        inputs, {}, token_to_logits, EOS_ID, rng0, topp=0.55,
        topk=0)  # anything under 0.6 will trigger deterministic decoding.

    expected_output = np.array([[3] * 20])
    expected_output = jnp.expand_dims(expected_output, 1)

    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_array_equal(scores, [[0.]])

    # temperature is applied first, so the distribution becomes
    # (0.27, 0, 0.069, 0.65), so if topp is 0.63, it should become greedy.
    decodes, scores = decoding.temperature_sample(
        inputs, {},
        token_to_logits,
        EOS_ID,
        rng0,
        temperature=0.8,
        topp=0.63,
        topk=0)

    expected_output = np.array([[3] * 20])
    expected_output = jnp.expand_dims(expected_output, 1)

    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_array_equal(scores, [[0.]])

  def test_dynamic_topp_max_decode_steps(self):
    rng0 = jax.random.PRNGKey(0)
    inputs = np.zeros((1, 20), dtype=np.int32)

    token_to_logits = mock.Mock()

    # logits correspond to (0.3, 0, 0.1, 0.6)
    token_to_logits.return_value = (np.array([[-1.2, -1e7, -2.3, -0.51]],
                                             dtype=np.float32), {})

    def dynamic_decode_fn(inputs, temperature, topp, max_decode_steps):
      return decoding.temperature_sample(
          inputs, {},
          token_to_logits,
          EOS_ID,
          rng0,
          temperature=temperature,
          topp=topp,
          topk=0,
          max_decode_steps=max_decode_steps)

    dynamic_decode_fn_jit = jax.jit(dynamic_decode_fn)

    decodes, scores = dynamic_decode_fn_jit(inputs, 0.8, 0.63, 10)

    expected_output = np.array([[3] * 10 + [0] * 10])
    expected_output = jnp.expand_dims(expected_output, 1)

    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_array_equal(scores, [[0.]])

  def test_topp_log_probs(self):
    rng0 = jax.random.PRNGKey(0)
    inputs = np.zeros((1, 1), dtype=np.int32)

    token_to_logits = mock.Mock()

    # logits correspond to (0.3, 0, 0.1, 0.6)
    token_to_logits.return_value = (np.array([[-1.2, NEG_INF, -2.3, -0.51]],
                                             dtype=np.float32), {})

    with jax.disable_jit():
      # this lets us see logits after topp and topk are applied
      with mock.patch.object(jax.random, 'categorical') as mocked:
        mocked.return_value = jnp.array([0], dtype=jnp.int32)
        decodes, _ = decoding.temperature_sample(
            inputs, {},
            token_to_logits,
            EOS_ID,
            rng0,
            temperature=1.4,
            topp=0.7,
            topk=0)

    self.assertLen(token_to_logits.call_args_list, 1)
    np.testing.assert_array_equal(decodes, jnp.asarray([[[0]]]))

    np.testing.assert_array_almost_equal(
        mocked.call_args_list[0][0][1],
        jnp.asarray([[-0.85714293, NEG_INF, NEG_INF, -0.36428571]]))

  def test_add_beam_dim(self):
    x = np.array([[0, 5, 1, 0], [0, 8, 6, 9]], dtype=np.int32)
    y = decoding.add_beam_dim(x, beam_size=3)
    self.assertEqual(y.shape, (2, 3, 4))
    np.testing.assert_array_equal([[[0, 5, 1, 0], [0, 5, 1, 0], [0, 5, 1, 0]],
                                   [[0, 8, 6, 9], [0, 8, 6, 9], [0, 8, 6, 9]]],
                                  y)

  def test_flat_batch_beam_expand(self):
    x = np.array([[0, 5, 1, 0], [0, 8, 6, 9]], dtype=np.int32)
    np.testing.assert_array_equal(
        [[0, 5, 1, 0], [0, 5, 1, 0], [0, 8, 6, 9], [0, 8, 6, 9]],
        decoding.flat_batch_beam_expand(x, beam_size=2))

  def test_top_k_two_stage(self):

    def _test_top_k(batch_size, k):
      # Pick sufficiently large seq_len.
      seq_len = 2047 * k * batch_size
      seq = np.arange(seq_len)
      np.random.shuffle(seq)
      x = jnp.reshape(seq, (batch_size, int(seq_len / batch_size))).astype(
          jnp.float32)
      np.testing.assert_almost_equal(
          decoding.top_k_two_stage(x, k), jax.lax.top_k(x, k), decimal=5)

    # Test small batch cases (batch={1,8}, k=16).
    _test_top_k(1, 16)
    _test_top_k(8, 16)
    # Test large batch cases (batch={9,32}, k=11).
    _test_top_k(9, 11)
    _test_top_k(32, 11)

  def test_cache_map(self):
    cache = {
        'layers_0': {
            'cached_key': jnp.ones([3, 6]),
            'cached_values': jnp.ones([3, 6]),
            'cache_index': jnp.ones([
                3,
            ]),
        },
        'layers_1': {
            'self_attention': {
                'cached_key': jnp.ones([2, 7]),
                'cached_values': jnp.ones([5, 8]),
                'cache_index': jnp.array(1),
            },
            'encoder_decoder_attention': {
                'cached_key': jnp.ones([10, 12, 2]),
                'cached_values': jnp.ones([4, 7, 2]),
                'cache_index': jnp.ones([4, 5, 6]),
            }
        },
    }

    fn = functools.partial(jnp.add, 4)

    gold_cache = {
        'layers_0': {
            'cached_key': fn(jnp.ones([3, 6])),
            'cached_values': fn(jnp.ones([3, 6])),
            'cache_index': jnp.ones([
                3,
            ]),
        },
        'layers_1': {
            'self_attention': {
                'cached_key': fn(jnp.ones([2, 7])),
                'cached_values': fn(jnp.ones([5, 8])),
                'cache_index': jnp.array(1),
            },
            'encoder_decoder_attention': {
                'cached_key': fn(jnp.ones([10, 12, 2])),
                'cached_values': fn(jnp.ones([4, 7, 2])),
                'cache_index': jnp.ones([4, 5, 6]),
            }
        }
    }

    # `jax.tree_multimap` was removed from JAX; `tree_map` accepts extra trees.
    jax.tree_util.tree_map(np.testing.assert_array_equal,
                           decoding.cache_map(fn, cache), gold_cache)

  def test_cache_map_with_index(self):
    cache = {
        'layers_0': {
            'cached_key': jnp.ones([3, 6]),
            'cached_values': jnp.ones([3, 6]),
            'cache_index': jnp.ones([
                3,
            ]),
        },
        'layers_1': {
            'relpos_bias': {
                'cached_bias': jnp.ones([1, 5, 3]),
            },
            'self_attention': {
                'cached_key': jnp.ones([2, 7]),
                'cached_values': jnp.ones([5, 8]),
                'cache_index': jnp.array(1),
            },
            'encoder_decoder_attention': {
                'cached_key': jnp.ones([10, 12, 2]),
                'cached_values': jnp.ones([4, 7, 2]),
                'cache_index': jnp.ones([4, 5, 6]),
            }
        },
        'position_embedder': {
            'position_embedder_index': jnp.array(-1),
        },
    }

    fn = functools.partial(jnp.add, 8)

    gold_cache = {
        'layers_0': {
            'cached_key': fn(jnp.ones([3, 6])),
            'cached_values': fn(jnp.ones([3, 6])),
            'cache_index': fn(jnp.ones([
                3,
            ])),
        },
        'layers_1': {
            'relpos_bias': {
                'cached_bias': jnp.ones([1, 5, 3]),
            },
            'self_attention': {
                'cached_key': fn(jnp.ones([2, 7])),
                'cached_values': fn(jnp.ones([5, 8])),
                'cache_index': fn(jnp.array(1)),
            },
            'encoder_decoder_attention': {
                'cached_key': fn(jnp.ones([10, 12, 2])),
                'cached_values': fn(jnp.ones([4, 7, 2])),
                'cache_index': fn(jnp.ones([4, 5, 6])),
            }
        },
        'position_embedder': {
            'position_embedder_index': jnp.array(-1),
        },
    }

    # `jax.tree_multimap` was removed from JAX; `tree_map` accepts extra trees.
    jax.tree_util.tree_map(np.testing.assert_array_equal,
                           decoding.cache_map(fn, cache, apply_to_index=True),
                           gold_cache)

  def test_beam_search(self):
    # Toy problem, we have 4 states, A, B, START, END, (plus PAD).
    # Scores are given by a first-order Markov model.
    batch_size = 2
    beam_size = 2
    # PAD doesn't matter for this test, but part of the contract for beam_search
    # is giving the PAD token id 0.
    states = ['PAD', 'A', 'B', 'START-', '-END']
    num_states = len(states)
    decode_length = 7

    # Edge potentials (written inside edges for diagonals):
    #            1      -1     1      -1
    #         A ---- A ---- A ---- A ---- A
    #       0   \  -1  \  1   \  -1  \  1   0
    # START      X      X      X      X       END
    #       0   /  -1  /  1   /  -1  /  1   0
    #         B ---- B ---- B ---- B ---- B
    #            1      -1     1      -1

    # put the above edge potentials in a 3-tensor
    ab_edge_potentials = np.asarray([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]],
                                     [[1, -1], [-1, 1]], [[-1, 1], [1, -1]]])
    # now we have to add on the START, END states
    # and PAD at 0
    edge_potentials = np.ones([6, 5, 5]) * NEG_INF
    edge_potentials[1:5, 1:3, 1:3] = ab_edge_potentials
    # START can go to either A or B for free at t0
    edge_potentials[0, 3, 1] = 0
    edge_potentials[0, 3, 2] = 0
    # either A or B can go to END for free at t5
    edge_potentials[5, 1, 4] = 0
    edge_potentials[5, 2, 4] = 0
    # PAD can go to anything for free (doesn't matter for this test)
    edge_potentials[:, 0, :] = 0

    edge_potentials = jnp.asarray(edge_potentials)

    # at time 0, we start with state=START=3
    logits0 = jnp.asarray([NEG_INF, NEG_INF, NEG_INF, 0, NEG_INF])

    # add dummy flattened batch x beam dim for broadcasting
    logits0 = jnp.expand_dims(logits0, axis=0)
    edge_potentials = jnp.expand_dims(edge_potentials, axis=0)

    def tokens_to_logits(
        token_indices: jnp.ndarray, state_cache: Mapping[str, jnp.ndarray]
    ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
      cur_iter = state_cache['cur_iter']
      # grab edge potentials for the current timestep
      cur_edge_potentials = jnp.take_along_axis(
          edge_potentials,
          jnp.reshape(
              jnp.maximum(0, cur_iter[:, 0].astype(jnp.int32) - 1),
              (batch_size * beam_size, 1, 1, 1)),
          axis=1)
      cur_edge_potentials = jnp.squeeze(cur_edge_potentials, axis=1)
      # get "logits" from edge potentials for requested tokens (except at t0)
      cur_logits = jnp.matmul(
          jnp.reshape(
              jax.nn.one_hot(token_indices, num_states, axis=1),
              (batch_size * beam_size, 1, num_states)), cur_edge_potentials)
      cur_logits = jnp.squeeze(cur_logits, axis=1)
      # use our START-only logits for t0, otherwise use the edge potentials
      logits_for_tokens = jnp.where(cur_iter == 0, logits0, cur_logits)
      # update state in the cache
      new_cache = state_cache.copy()
      new_cache['cur_iter'] = cur_iter + 1
      return logits_for_tokens, new_cache

    init_cache = {}
    init_cache['cur_iter'] = jnp.zeros((batch_size, 1))

    top_scoring, _ = decoding.beam_search(
        inputs=np.zeros([batch_size, decode_length]),
        cache=init_cache,
        tokens_to_logits=tokens_to_logits,
        eos_id=4,
        num_decodes=beam_size,
        alpha=0.0,
        max_decode_len=decode_length)

    # The two top scoring sequences should be a tie between
    # START-AABBA-END
    # and
    # START-BBAAB-END
    # (and greedy beam search will find both these with just two beams)

    top_scoring_strings = [
        ''.join(states[tok]
                for tok in top_scoring[0, i, :])
        for i in range(beam_size)
    ]

    expected = ['START-AABBA-END', 'START-BBAAB-END']
    np.testing.assert_array_equal(expected, top_scoring_strings)

  def test_beam_search_force_decode_prefix(self):
    beam_size = 2

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Use id 2 then 3 for batch element 0 and id 3 then 2 for element 1.
      logits = np.repeat(
          np.expand_dims(
              np.array([[-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4],
                        [-1e7, -1e10, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4]],
                       dtype=np.float32),
              axis=1), [beam_size],
          axis=1)
      logits = decoding.flatten_beam_dim(logits)
      return logits, {}

    # batch element 0 has length 1 and element 1 has length 2.
    inputs = np.array([[0, 7, 0, 0, 0], [0, 4, 5, 0, 0]], dtype=np.int32)
    rolled_inputs = np.array([[7, 0, 0, 0, 0], [4, 5, 0, 0, 0]],
                             dtype=np.int32)
    beam_search_sequences, decoding_scores = decoding.beam_search(
        inputs, {}, token_to_logits, EOS_ID, num_decodes=beam_size, alpha=0)

    # Prefixes are forced depending on inputs.
    # Beam search sequences and corresponding scores are in reverse order.
    self.assertTrue(np.all(np.diff(decoding_scores) >= 0))
    expected = np.array([[[7, 3, 2, 2, 2], [7, 2, 2, 2, 2]],
                         [[4, 5, 2, 3, 3], [4, 5, 3, 3, 3]]])
    np.testing.assert_array_equal(expected, beam_search_sequences)

    expected_scores = []
    batch_logits = np.array([[-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4],
                             [-1e7, -1e10, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4]],
                            dtype=np.float32)
    for batch, logits, prompt in zip(expected, batch_logits, rolled_inputs):
      beam_expected_scores = []
      for beam in batch:
        log_probs = jax.nn.log_softmax(logits)
        # Add them directly since they are static.
        beam_scores = []
        for token, prompt_token in zip(beam, prompt):
          if prompt_token != 0:
            beam_scores.append(0)
          else:
            beam_scores.append(log_probs[token])
        beam_expected_scores.append(sum(beam_scores))
      expected_scores.append(beam_expected_scores)
    np.testing.assert_allclose(expected_scores, decoding_scores, atol=1e-5)

  def test_beam_search_force_decode_no_prefix(self):
    beam_size = 2

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Use id 2 then 3 for batch element 0 and id 3 then 2 for element 1.
      logits = np.repeat(
          np.expand_dims(
              np.array([[-1e7, -1e10, -0.1, -0.9], [-1e7, -1e10, -0.9, -0.1]],
                       dtype=np.float32),
              axis=1), [beam_size],
          axis=1)
      logits = decoding.flatten_beam_dim(logits)
      return logits, {}

    # No prefix is passed.
    inputs = np.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], dtype=np.int32)
    beam_search_sequences, decoding_scores = decoding.beam_search(
        inputs, {}, token_to_logits, EOS_ID, num_decodes=beam_size)

    # Prefixes are forced depending on inputs.
    # Beam search sequences and corresponding scores are in reverse order.
    self.assertTrue(np.all(np.diff(decoding_scores) >= 0))
    expected = np.array([[[3, 2, 2, 2, 2], [2, 2, 2, 2, 2]],
                         [[2, 3, 3, 3, 3], [3, 3, 3, 3, 3]]])
    np.testing.assert_array_equal(expected, beam_search_sequences)


if __name__ == '__main__':
  absltest.main()