# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for t5x.models."""

import functools
from unittest import mock

from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
import flax
from flax import traverse_util
import jax
import jax.numpy as jnp
import numpy as np
import t5.data.tasks  # pylint:disable=unused-import
from t5x import decoding
from t5x import models
from t5x import partitioning
from t5x import test_utils
from t5x import trainer as trainer_lib
from t5x import utils
import tensorflow as tf

# Parse absl flags test_srcdir and test_tmpdir.
jax.config.parse_flags_with_absl()

PartitionSpec = partitioning.PartitionSpec


class ModelsTest(parameterized.TestCase):

  def test_remove_prefix(self):
    sequences = np.array(
        [[1, 2, 3, 4, 5, 6, 7, 0], [6, 7, 8, 9, 10, 11, 0, 0]])
    prefix_lengths = np.array([2, 4])
    expected = [[3, 4, 5, 6, 7, 0, 0, 0], [10, 11, 0, 0, 0, 0, 0, 0]]
    remove_prefix = jax.jit(models.remove_prefix)
    actual = remove_prefix(sequences, prefix_lengths)
    np.testing.assert_array_equal(actual, expected)

  def test_remove_prefix_zero_len_prefix(self):
    sequences = np.array(
        [[1, 2, 3, 4, 5, 6, 7, 0], [6, 7, 8, 9, 10, 11, 0, 0]])
    prefix_lengths = np.array([0, 0])
    remove_prefix = jax.jit(models.remove_prefix)
    actual = remove_prefix(sequences, prefix_lengths)
    # The expected output is the original sequences.
    np.testing.assert_array_equal(actual, sequences)


BATCH_SIZE, ENCODER_LEN, MAX_DECODE_LEN, EMBED_DIM = 2, 3, 4, 5


class EncoderDecoderModelTest(parameterized.TestCase):

  @parameterized.named_parameters(
      dict(
          testcase_name='no_types',
          shapes={
              'encoder_input_tokens': [1, 512],
              'decoder_input_tokens': [1, 62]
          },
          types=None),
      dict(
          testcase_name='int32',
          shapes={
              'encoder_input_tokens': [1, 512],
              'decoder_input_tokens': [1, 62]
          },
          types={
              'encoder_input_tokens': jnp.int32,
              'decoder_input_tokens': jnp.int32
          }),
      dict(
          testcase_name='float32',
          shapes={
              'encoder_input_tokens': [1, 512],
              'decoder_input_tokens': [1, 62],
              'encoder_positions': [1, 512],
              'decoder_positions': [1, 62],
          },
          types={
              'encoder_input_tokens': jnp.int32,
              'decoder_input_tokens': jnp.int32,
              'encoder_positions': jnp.int32,
              'decoder_positions': jnp.int32
          }),
      dict(
          testcase_name='float32_segment_ids',
          shapes={
              'encoder_input_tokens': [1, 512],
              'decoder_input_tokens': [1, 62],
              'encoder_segment_ids': [1, 512],
              'decoder_segment_ids': [1, 62],
          },
          types={
              'encoder_input_tokens': jnp.int32,
              'decoder_input_tokens': jnp.int32,
              'encoder_segment_ids': jnp.int32,
              'decoder_segment_ids': jnp.int32
          }),
  )
  def test_get_initial_variables_shapes_and_types(self, shapes, types):
    mock_transformer = mock.Mock()
    mock_transformer.init.return_value = {'params': {}}
    mock_optimizer_def = mock.Mock()
    rng = mock.Mock()

    def mock_init(self):
      self.module = mock_transformer
      self.optimizer_def = mock_optimizer_def

    with mock.patch.object(
        models.EncoderDecoderModel, '__init__', new=mock_init):
      model = models.EncoderDecoderModel()

    model.get_initial_variables(rng, shapes, types)

    if types is None:
      encoder_input = jnp.ones(
          shapes['encoder_input_tokens'], dtype=jnp.float32)
      decoder_input = jnp.ones(
          shapes['decoder_input_tokens'], dtype=jnp.float32)
    else:
      encoder_input = jnp.ones(
          shapes['encoder_input_tokens'], dtype=types['encoder_input_tokens'])
      decoder_input = jnp.ones(
          shapes['decoder_input_tokens'], dtype=types['decoder_input_tokens'])

    # Using `.assert_called_once_with` doesn't work because the simple
    # comparison it does for the array arguments fails (truth value of an
    # array is ambiguous).
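    # `mock.call_args` behaves like an (args, kwargs) pair, so `[0]` below
    # indexes the positional arguments and `[1]` the keyword arguments.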
    called_with = mock_transformer.init.call_args
    self.assertEqual(called_with[0][0], rng)
    np.testing.assert_allclose(called_with[0][1], encoder_input)
    np.testing.assert_allclose(called_with[0][2], decoder_input)
    np.testing.assert_allclose(called_with[0][3], decoder_input)

    if 'encoder_positions' in shapes:
      encoder_positions = jnp.ones(
          shapes['encoder_positions'], dtype=types['encoder_positions'])
      np.testing.assert_allclose(called_with[1]['encoder_positions'],
                                 encoder_positions)
    else:
      self.assertIsNone(called_with[1]['encoder_positions'])
    if 'decoder_positions' in shapes:
      decoder_positions = jnp.ones(
          shapes['decoder_positions'], dtype=types['decoder_positions'])
      np.testing.assert_allclose(called_with[1]['decoder_positions'],
                                 decoder_positions)
    else:
      self.assertIsNone(called_with[1]['decoder_positions'])

    if 'encoder_segment_ids' in shapes:
      encoder_segment_ids = jnp.ones(
          shapes['encoder_segment_ids'], dtype=types['encoder_segment_ids'])
      np.testing.assert_allclose(called_with[1]['encoder_segment_ids'],
                                 encoder_segment_ids)
    else:
      self.assertIsNone(called_with[1]['encoder_segment_ids'])
    if 'decoder_segment_ids' in shapes:
      decoder_segment_ids = jnp.ones(
          shapes['decoder_segment_ids'], dtype=types['decoder_segment_ids'])
      np.testing.assert_allclose(called_with[1]['decoder_segment_ids'],
                                 decoder_segment_ids)
    else:
      self.assertIsNone(called_with[1]['decoder_segment_ids'])

    self.assertFalse(called_with[1]['decode'])
    self.assertFalse(called_with[1]['enable_dropout'])

  @parameterized.named_parameters(
      dict(testcase_name='no_force_decoding', prompt_with_targets=False),
      dict(testcase_name='force_decoding', prompt_with_targets=True),
  )
  def test_prompt_with_targets(self, prompt_with_targets):
    batch_size, encoder_len, max_decode_len, emb_dim = 2, 3, 4, 5
    batch = {
        'encoder_input_tokens':
            np.zeros((batch_size, encoder_len), dtype=np.int32),
        'decoder_input_tokens':
            np.full([batch_size, max_decode_len], 2, dtype=np.int32)
    }

    # These dummy logits represent the probability distribution where all the
    # probability mass is in one item (i.e., degenerate distribution). For
    # batch element 0, it is vocabulary index 3.
    # We test `_predict_step` to avoid having to define a task and its
    # vocabulary.
    dummy_logits = jnp.expand_dims(
        jnp.array([[-1e7, -1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, -1e7, 0]]),
        axis=1)

    mock_decode_fn = mock.Mock()
    mock_decode_fn.return_value = (
        np.full([batch_size, max_decode_len, 1], 3, dtype=np.int32),
        np.full([batch_size, 1], 1.0, dtype=np.float32))

    class MockModule:

      def __init__(self):
        self.dtype = jnp.float32

      def apply(self, *args, method=None, **kwargs):
        del args, kwargs
        if method is None:  # Used for `module.__call__`.
          return (dummy_logits, {'cache': {}})
        else:
          return method()

      def encode(self):
        return jnp.zeros((batch_size, encoder_len, emb_dim))

      def decode(self):
        return (dummy_logits, {'cache': {}})

    def mock_init(self):
      self.module = MockModule()
      self.module.scan_layers = False
      self._input_vocabulary = mock.Mock(eos_id=1)
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = mock_decode_fn

    with mock.patch.object(
        models.EncoderDecoderModel, '__init__', new=mock_init):
      model = models.EncoderDecoderModel()

    model.predict_batch_with_aux({},
                                 batch,
                                 prompt_with_targets=prompt_with_targets)

    if prompt_with_targets:
      expected_inputs = batch['decoder_input_tokens']
    else:
      expected_inputs = np.zeros([batch_size, max_decode_len], dtype=np.int32)

    assert mock_decode_fn.call_count == 1
    # Look at the kwargs call list for inputs; `assert_called_with` doesn't
    # work well with np.array comparison.
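    # Each entry of `mock_calls` is a (name, args, kwargs) triple, so `[2]`
    # selects the keyword arguments of the recorded call.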
    np.testing.assert_array_equal(mock_decode_fn.mock_calls[0][2]['inputs'],
                                  expected_inputs)

  def test_predict_batch_loop_and_caches_are_equal(self):
    vocab_size = 50
    lengths = np.array([[2], [3]])
    batch_size, beam_size, encoder_len, max_decode_len = 2, 2, 3, 7
    batch = {
        'encoder_input_tokens':
            np.zeros((batch_size, encoder_len), dtype=np.int32),
        'decoder_target_tokens':
            np.zeros((batch_size, encoder_len), dtype=np.int32),
        'decoder_input_tokens':
            np.concatenate(
                [
                    np.expand_dims(
                        np.concatenate([
                            [0],
                            np.arange(9, 9 + lengths[0][0], dtype=np.int32),
                            np.zeros((max_decode_len - lengths[0][0] - 1),
                                     dtype=np.int32)
                        ]),
                        axis=0),  # First element
                    np.expand_dims(
                        np.concatenate([
                            [0],
                            np.arange(3, 3 + lengths[1][0], dtype=np.int32),
                            np.zeros((max_decode_len - lengths[1][0] - 1),
                                     dtype=np.int32)
                        ]),
                        axis=0)  # Second element
                ],
                axis=0),
    }

    model = test_utils.get_t5_test_model(vocab_size=50)
    module = model.module
    params = module.init(
        jax.random.PRNGKey(0),
        jnp.ones((batch_size, encoder_len)),
        jnp.ones((batch_size, max_decode_len)),
        jnp.ones((batch_size, max_decode_len)),
        enable_dropout=False)['params']

    def mock_init(self):
      self.module = module
      # Set the EOS token to be larger than the vocabulary size. This forces
      # the model to decode all the way to `max_decode_length`, allowing us to
      # test behavior when one element reaches the end before the others.
      self._output_vocabulary = mock.Mock(eos_id=vocab_size + 12)
      self._decode_fn = decoding.beam_search

    with mock.patch.object(
        models.EncoderDecoderModel, '__init__', new=mock_init):
      model = models.EncoderDecoderModel()

    with mock.patch.object(
        model, '_compute_logits_from_slice',
        autospec=True) as tokens_to_logits_mock:
      # Make the side effect of the mock call the method on the class, with
      # the instance partialed in as `self`. This lets us call the actual code
      # and record the inputs, without the infinite loop you would get from
      # calling `instance.method`.
      tokens_to_logits_mock.side_effect = functools.partial(
          models.EncoderDecoderModel._compute_logits_from_slice, model)
      # Disable jit, so that the `lax.while_loop` isn't traced, as the
      # collection of tracers in the mock call_args would generally trigger a
      # tracer leak error.
      with jax.disable_jit():
        _ = model.predict_batch_with_aux(
            params, batch, prompt_with_targets=True, num_decodes=2)

    # Collect all the input tokens to our tokens_to_logits function.
    all_inputs = []
    all_cache_keys = []  # Collect all the cache keys
    all_cache_values = []  # Collect all the cache values
    # Currently force decoding generates logits at every step. We should have
    # `max_decode_length` calls to our tokens -> logits func.
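    # Beam search expands each batch element into `beam_size` beams, so the
    # leading dimension of every recorded call below is
    # `batch_size * beam_size` (the `B * Be` in the shape comments).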
    self.assertLen(tokens_to_logits_mock.call_args_list, max_decode_len)
    for tokens_call in tokens_to_logits_mock.call_args_list:
      # Inputs: [B * Be, 1]
      inputs, cache = tokens_call[0]
      cache = flax.core.unfreeze(cache)
      # Cache: [B * Be, 1] * #Layers
      cache_keys = [
          v for k, v in traverse_util.flatten_dict(cache).items()
          if k[-1] == 'cached_key'
      ]
      cache_values = [
          v for k, v in traverse_util.flatten_dict(cache).items()
          if k[-1] == 'cached_value'
      ]
      all_inputs.append(inputs)
      all_cache_keys.append(cache_keys)
      all_cache_values.append(cache_values)

    # Convert inputs to a single block [B, DL, Be]
    all_inputs = np.concatenate(all_inputs, axis=1)
    # Convert caches into a single block per layer [B * Be, DL] * L
    all_cache_keys = [np.stack(c, axis=1) for c in zip(*all_cache_keys)]
    all_cache_values = [np.stack(c, axis=1) for c in zip(*all_cache_values)]

    # Make sure that for each batch, the cache for each beam is identical when
    # the prompt is being forced.
    for b in range(batch_size):
      for i, input_token in enumerate(all_inputs[b * beam_size]):
        if i < lengths[b]:
          self.assertEqual(input_token, batch['decoder_input_tokens'][b][i])
          # For all layers.
          for cache_keys in all_cache_keys:
            np.testing.assert_array_equal(cache_keys[b * beam_size][i],
                                          cache_keys[b * beam_size + 1][i])
          for cache_values in all_cache_values:
            np.testing.assert_array_equal(cache_values[b * beam_size][i],
                                          cache_values[b * beam_size + 1][i])

  def test_score_batch(self):
    encoder_input_tokens = jnp.ones((2, 3))
    # For this test, decoder input and target tokens are dummy values.
    decoder_input_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]])
    decoder_target_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]])
    decoder_loss_weights = jnp.array([[1, 1, 1, 0], [0, 1, 0, 1]])
    logits = jnp.arange(0, 24).reshape((2, 4, 3))
    params = {'foo': jnp.zeros(3)}

    mock_transformer = mock.Mock()
    mock_transformer.apply.return_value = logits
    mock_transformer.dtype = jnp.float32

    batch = {
        'encoder_input_tokens': encoder_input_tokens,
        'decoder_input_tokens': decoder_input_tokens,
        'decoder_target_tokens': decoder_target_tokens,
        'decoder_loss_weights': decoder_loss_weights
    }

    def mock_init(self):
      self.module = mock_transformer

    with mock.patch.object(
        models.EncoderDecoderModel, '__init__', new=mock_init):
      model = models.EncoderDecoderModel()
      res = model.score_batch(params, batch)

    mock_transformer.apply.assert_called_with({'params': params},
                                              encoder_input_tokens,
                                              decoder_input_tokens,
                                              decoder_target_tokens,
                                              encoder_segment_ids=None,
                                              decoder_segment_ids=None,
                                              encoder_positions=None,
                                              decoder_positions=None,
                                              decode=False,
                                              enable_dropout=False,
                                              rngs=None,
                                              mutable=False)
    np.testing.assert_allclose(res, [-3.222973, -1.815315], rtol=1e-4)

  def test_score_batch_can_return_intermediates(self):
    encoder_input_tokens = jnp.ones((2, 3))
    # For this test, decoder input and target tokens are dummy values.
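    # The loss weights below zero out the padding (id 0) positions of the
    # targets, so only the non-padding logits contribute to the scores.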
    decoder_input_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]])
    decoder_target_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]])
    decoder_loss_weights = jnp.array([[1, 1, 1, 0], [0, 1, 0, 1]])
    logits = jnp.arange(0, 24).reshape((2, 4, 3))
    modified_variables = {'intermediates': {'bar': jnp.ones(5)}}
    params = {'foo': jnp.zeros(3)}

    mock_transformer = mock.Mock()
    mock_transformer.apply.return_value = (logits, modified_variables)
    mock_transformer.dtype = jnp.float32

    batch = {
        'encoder_input_tokens': encoder_input_tokens,
        'decoder_input_tokens': decoder_input_tokens,
        'decoder_target_tokens': decoder_target_tokens,
        'decoder_loss_weights': decoder_loss_weights
    }

    def mock_init(self):
      self.module = mock_transformer

    with mock.patch.object(
        models.EncoderDecoderModel, '__init__', new=mock_init):
      model = models.EncoderDecoderModel()
      scores, intermediates = model.score_batch(
          params, batch, return_intermediates=True)

    mock_transformer.apply.assert_called_with({'params': params},
                                              encoder_input_tokens,
                                              decoder_input_tokens,
                                              decoder_target_tokens,
                                              encoder_segment_ids=None,
                                              decoder_segment_ids=None,
                                              encoder_positions=None,
                                              decoder_positions=None,
                                              decode=False,
                                              enable_dropout=False,
                                              rngs=None,
                                              mutable=['intermediates'])
    np.testing.assert_allclose(scores, [-3.222973, -1.815315], rtol=1e-4)
    # Incumbent intermediates are passed out unchanged.
    np.testing.assert_allclose(intermediates['bar'], jnp.ones(5))
    # A new collection of decoder intermediates is inserted by score_batch().
    np.testing.assert_allclose(intermediates['decoder']['loss_weights'][0],
                               decoder_loss_weights)
    np.testing.assert_allclose(intermediates['decoder']['target_tokens'][0],
                               decoder_target_tokens)

  def test_train_transformer_wmt(self):
    # Dummy input data.
    input_shape = (16, 8)
    encoder_input_tokens = np.ones(shape=input_shape, dtype=np.float32)
    decoder_input_tokens = 5 * np.ones(shape=input_shape, dtype=np.float32)
    decoder_target_tokens = 5 * np.ones(input_shape, dtype=np.float32)
    input_data = {
        'encoder_input_tokens': encoder_input_tokens,
        'decoder_input_tokens': decoder_input_tokens,
        'decoder_target_tokens': decoder_target_tokens
    }

    partitioner = partitioning.PjitPartitioner(num_partitions=1)

    model = test_utils.get_t5_test_model()

    ds_iter = tf.data.Dataset.from_tensors(input_data).as_numpy_iterator()
    input_shapes = {k: input_shape for k in input_data}

    train_state_initializer = utils.TrainStateInitializer(
        optimizer_def=model.optimizer_def,
        init_fn=model.get_initial_variables,
        input_shapes=input_shapes,
        partitioner=partitioner)
    train_state_axes = train_state_initializer.train_state_axes
    train_state = train_state_initializer.from_scratch(jax.random.PRNGKey(0))

    trainer = trainer_lib.Trainer(
        model,
        train_state=train_state,
        partitioner=partitioner,
        eval_names=[],
        summary_dir=None,
        train_state_axes=train_state_axes,
        rng=jax.random.PRNGKey(0),
        learning_rate_fn=lambda x: 0.001,
        num_microbatches=1)

    trainer.train(ds_iter, 1)
    logging.info('optimizer after first step %s', train_state.params)

  @parameterized.parameters(
      {'decode_fn': decoding.beam_search},
      {'decode_fn': functools.partial(decoding.temperature_sample, topk=4)})
  def test_predict_batch(self, decode_fn):
    batch_size, encoder_len, max_decode_len, emb_dim = 2, 3, 4, 5
    batch = {
        'encoder_input_tokens':
            np.zeros((batch_size, encoder_len), dtype=np.int32),
        'decoder_input_tokens':
            np.zeros((batch_size, max_decode_len), dtype=np.int32)
    }

    # These dummy logits represent the probability distribution where all the
    # probability mass is in one item (i.e.,
    # degenerate distribution). For batch element 0, it is vocabulary index 2.
    # We test `_predict_step` to avoid having to define a task and its
    # vocabulary.
    dummy_logits = jnp.expand_dims(
        jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1)

    class MockModule:

      def __init__(self):
        self.dtype = jnp.float32

      def apply(self, *args, method=None, **kwargs):
        del args, kwargs
        if method is None:  # Used for `module.__call__`.
          return (dummy_logits, {'cache': {}})
        else:
          return method()

      def encode(self):
        return jnp.zeros((batch_size, encoder_len, emb_dim))

      def decode(self):
        return (dummy_logits, {'cache': {}})

    def mock_init(self):
      self.module = MockModule()
      self.module.scan_layers = False
      self._input_vocabulary = mock.Mock(eos_id=1)
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = decode_fn

    with mock.patch.object(
        models.EncoderDecoderModel, '__init__', new=mock_init):
      model = models.EncoderDecoderModel()

    actual = model.predict_batch({}, batch)
    # The predicted token for the first batch element is always 2 and it is 3
    # for the second batch element.
    expected = [[2] * max_decode_len, [3] * max_decode_len]
    np.testing.assert_array_equal(actual, expected)

  def test_predict_batch_rng(self):
    batch = {
        'encoder_input_tokens': np.zeros((2, 1), dtype=np.int32),
        'decoder_input_tokens': np.zeros((2, 2), dtype=np.int32)
    }
    decode_fn_mock = mock.Mock(
        return_value=(np.zeros((2, 2, 3)), np.zeros((2, 2))))

    def mock_init(self):
      self.module = mock.Mock(
          apply=mock.Mock(side_effect=lambda *_, **kwargs: (  # pylint:disable=g-long-lambda,g-long-ternary
              np.zeros((2, 2)), {
                  'cache': None
              }) if 'mutable' in kwargs else np.zeros((2, 2))))
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = decode_fn_mock

    with mock.patch.object(
        models.EncoderDecoderModel, '__init__', new=mock_init):
      model = models.EncoderDecoderModel()

    # No RNG
    model.predict_batch({}, batch)
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertNotIn('decode_rng', decode_fn_kwargs)

    # No RNG (w/ aux)
    model.predict_batch_with_aux({}, batch)
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertNotIn('decode_rng', decode_fn_kwargs)

    # decoder_params RNG
    model.predict_batch_with_aux({}, batch, decoder_params={'decode_rng': 3})
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertEqual(decode_fn_kwargs['decode_rng'], 3)

    # rng RNG
    model.predict_batch({}, batch, rng=4)
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertEqual(decode_fn_kwargs['decode_rng'], 4)

    # rng RNG (w/ aux)
    model.predict_batch_with_aux({}, batch, rng=4)
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertEqual(decode_fn_kwargs['decode_rng'], 4)

    # Both
    with self.assertRaisesWithLiteralMatch(
        ValueError, 'Got RNG both from the `rng` argument (4) and '
        "`decoder_params['decode_rng']` (3). Please specify one or the "
        'other.'):
      model.predict_batch_with_aux(
          {}, batch, rng=4, decoder_params={'decode_rng': 3})

  @parameterized.named_parameters(
      dict(
          testcase_name='int32',
          batch={
              'encoder_input_tokens':
                  np.zeros((BATCH_SIZE, ENCODER_LEN), dtype=np.int32),
              'decoder_input_tokens':
                  np.zeros((BATCH_SIZE, MAX_DECODE_LEN), dtype=np.int32)
          }),
      dict(
          testcase_name='float32',
          batch={
              'encoder_input_tokens':
                  np.zeros((BATCH_SIZE, ENCODER_LEN), dtype=np.float32),
              'decoder_input_tokens':
                  np.zeros((BATCH_SIZE, MAX_DECODE_LEN), dtype=np.float32)
          }))
  def test_predict_batch_fake_input_shapes_and_types(self, batch):
    # These dummy logits give a uniform distribution; this test only checks
    # the shapes and dtypes of the fake inputs used to initialize the cache.
    # We test `_predict_step` to avoid having to define a task and its
    # vocabulary.
    dummy_logits = jnp.ones((2, 1, 4), jnp.float32)

    class MockModule:

      def __init__(self):
        self.dtype = jnp.float32
        self.call_args_list = []

      def apply(self, *args, method=None, **kwargs):
        # Not sure why this isn't a real Mock, so just record the args/kwargs.
        self.call_args_list.append({'args': args, 'kwargs': kwargs})
        del args, kwargs
        if method is None:  # Used for `module.__call__`.
          return (dummy_logits, {'cache': {}})
        else:
          return method()

      def encode(self):
        return jnp.zeros((BATCH_SIZE, ENCODER_LEN, EMBED_DIM))

      def decode(self):
        return (dummy_logits, {'cache': {}})

    def mock_init(self):
      self.module = MockModule()
      self.module.scan_layers = False
      self._input_vocabulary = mock.Mock(eos_id=1)
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = decoding.beam_search
      self._inputs_bidirectional_attention = False

    with mock.patch.object(
        models.EncoderDecoderModel, '__init__', new=mock_init):
      model = models.EncoderDecoderModel()

    model.predict_batch({}, batch)

    fake_inputs = jnp.ones_like(batch['encoder_input_tokens'])
    fake_target = jnp.ones_like(batch['decoder_input_tokens'])

    cache_init_call = model.module.call_args_list[0]
    self.assertEqual(cache_init_call['args'][0], {'params': {}})
    np.testing.assert_allclose(cache_init_call['args'][1], fake_inputs)
    np.testing.assert_allclose(cache_init_call['args'][2], fake_target)
    np.testing.assert_allclose(cache_init_call['args'][3], fake_target)
    self.assertEqual(cache_init_call['kwargs'], {
        'decode': True,
        'enable_dropout': False,
        'mutable': ['cache']
    })


class DecoderOnlyModelTest(parameterized.TestCase):

  def test_predict_batch_visible_in_prefill(self):
    batch_size = 2
    seq_len = 10
    lengths = np.array([[6], [3]])
    batch = {
        'decoder_input_tokens':
            np.tile(
                np.expand_dims(np.arange(seq_len, dtype=np.int32), axis=0),
                (batch_size, 1)),
        'decoder_causal_attention':
            (lengths > np.arange(seq_len)).astype(np.int32)
    }

    dummy_logits = jnp.expand_dims(
        jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1)

    mock_module = mock.Mock()
    mock_module.apply.return_value = (dummy_logits, {'cache': {}})
    mock_module.dtype = jnp.float32

    def mock_init(self):
      self.module = mock_module
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = functools.partial(decoding.temperature_sample, topk=4)
      self._inputs_bidirectional_attention = False

    with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init):
      model = models.DecoderOnlyModel()

    model.predict_batch({}, batch)
    prefill_call = mock_module.apply.call_args_list[1]
    kwargs = prefill_call[1]
    inputs = prefill_call[1]['decoder_input_tokens']
    # Note that, for the prefill call, we use 'decoder_causal_attention' as
    # 'decoder_target_tokens'.
    targets = prefill_call[1]['decoder_target_tokens']

    self.assertTrue(kwargs['prefill'])
    np.testing.assert_array_equal(kwargs['prefill_lengths'],
                                  np.squeeze(lengths - 1, axis=-1))
    # Test that the non-padding values of the "targets" cover all of the
    # input, so it will all be considered in the attention mask.
    np.testing.assert_array_equal(inputs * targets, inputs)
    # Check that the first value of the target is 1; the first value of the
    # inputs is always 0, so the masking check wouldn't catch it if the target
    # had a 0 in the first location.
    np.testing.assert_array_equal(targets[:, 0], np.ones_like(targets[:, 0]))
    # Test that the targets are properly removed.
    # Our input is a sequence from 0 onward, so our largest value (the last
    # input) should be equal to its position (which is length - 1). If we
    # didn't mask the target correctly, we would expect a larger value in the
    # max.
    np.testing.assert_array_equal(
        np.max(inputs, axis=1), np.squeeze(lengths - 1, axis=-1))

  def test_predict_batch(self):
    batch = {
        'decoder_input_tokens':
            np.array([[0, 3, 4, 5, 6, 0, 0], [0, 7, 8, 9, 0, 0, 0]]),
        'decoder_causal_attention':
            np.array([[1, 1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0]])
    }
    # These dummy logits represent the probability distribution where all the
    # probability mass is in one item (i.e., degenerate distribution). For
    # batch element 0, it is vocabulary index 2.
    # We test `_predict_step` to avoid having to define a task and its
    # vocabulary.
    dummy_logits = jnp.expand_dims(
        jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1)

    mock_module = mock.Mock()
    mock_module.apply.return_value = (dummy_logits, {'cache': {}})
    mock_module.dtype = jnp.float32

    def mock_init(self):
      self.module = mock_module
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = functools.partial(decoding.temperature_sample, topk=4)
      self._inputs_bidirectional_attention = False

    with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init):
      model = models.DecoderOnlyModel()

    actual = model.predict_batch({}, batch)
    expected = [[2, 2, 2, 2, 2, 0, 0], [3, 3, 3, 3, 3, 3, 0]]
    # The expected progression of the first element of 'decoder_input_tokens':
    # [0, 3, 4, 5, 6, 0, 0] -> [0, 3, 4, 0, 0, 0, 0] ->
    # [3, 4, 2, 2, 2, 2, 2] -> [2, 2, 2, 2, 2, 0, 0]
    # The expected progression of the second element of
    # 'decoder_input_tokens':
    # [0, 7, 8, 9, 0, 0, 0] -> [0, 7, 0, 0, 0, 0, 0] ->
    # [7, 3, 3, 3, 3, 3, 3] -> [3, 3, 3, 3, 3, 3, 0]
    np.testing.assert_array_equal(actual, expected)

  def test_predict_batch_rng(self):
    batch = {
        'decoder_input_tokens': np.zeros((2, 2), dtype=np.int32),
        'decoder_causal_attention': np.zeros((2, 2), dtype=np.int32)
    }
    decode_fn_mock = mock.Mock(
        return_value=(np.zeros((2, 2, 3)), np.zeros((2, 2))))

    def mock_init(self):
      self.module = mock.Mock(
          apply=mock.Mock(side_effect=lambda *_, **kwargs: (  # pylint:disable=g-long-lambda,g-long-ternary
              np.zeros((2, 2)), {
                  'cache': None
              }) if 'mutable' in kwargs else np.zeros((2, 2))))
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = decode_fn_mock
      self._inputs_bidirectional_attention = False

    with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init):
      model = models.DecoderOnlyModel()

    # No RNG
    model.predict_batch({}, batch)
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertNotIn('decode_rng', decode_fn_kwargs)

    # No RNG (w/ aux)
    model.predict_batch_with_aux({}, batch)
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertNotIn('decode_rng', decode_fn_kwargs)

    # decoder_params RNG
    model.predict_batch_with_aux({}, batch, decoder_params={'decode_rng': 3})
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertEqual(decode_fn_kwargs['decode_rng'], 3)

    # rng RNG
    model.predict_batch({}, batch, rng=4)
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertEqual(decode_fn_kwargs['decode_rng'], 4)

    # rng RNG (w/ aux)
    model.predict_batch_with_aux({}, batch, rng=4)
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertEqual(decode_fn_kwargs['decode_rng'], 4)

    # Both
    with self.assertRaisesWithLiteralMatch(
        ValueError, 'Got RNG both from the `rng` argument (4) and '
        "`decoder_params['decode_rng']` (3). "
        'Please specify one or the other.'):
      model.predict_batch_with_aux(
          {}, batch, rng=4, decoder_params={'decode_rng': 3})

  def test_predict_batch_num_decodes_temperature_sample(self):
    batch = {
        'decoder_input_tokens':
            np.array([
                [0, 3, 4, 5, 6, 0, 0],
            ]),
        'decoder_causal_attention':
            np.array([
                [1, 1, 1, 0, 0, 0, 0],
            ])
    }
    # These dummy logits represent the probability distribution where all the
    # probability mass is in one item (i.e., degenerate distribution). For
    # batch element 0, it is vocabulary index 2. We have two samples.
    # Technically these should be identical since the prompts are the same,
    # but this makes testing easier.
    dummy_logits = jnp.expand_dims(
        jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1)

    mock_module = mock.Mock()
    mock_module.apply.return_value = (dummy_logits, {'cache': {}})
    mock_module.dtype = jnp.float32

    def mock_init(self):
      self.module = mock_module
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = functools.partial(decoding.temperature_sample, topk=4)
      self._inputs_bidirectional_attention = False

    with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init):
      model = models.DecoderOnlyModel()

    actual_output, aux = model.predict_batch_with_aux(
        {}, batch, num_decodes=2, return_all_decodes=True)
    expected_output = [[[2, 2, 2, 2, 2, 0, 0], [3, 3, 3, 3, 3, 0, 0]]]
    expected_scores = [[0., 0.]]
    # Both samples decode from the same prompt; the progression is the same as
    # in `test_predict_batch` above, except that the first sample emits token
    # 2 and the second sample emits token 3:
    # [0, 3, 4, 5, 6, 0, 0] -> [0, 3, 4, 0, 0, 0, 0] ->
    # [3, 4, 2, 2, 2, 2, 2] -> [2, 2, 2, 2, 2, 0, 0]
    np.testing.assert_array_equal(actual_output, expected_output)
    np.testing.assert_array_equal(aux['scores'], expected_scores)

  def test_predict_batch_fake_input_shapes_and_types(self):
    # The input and causal attention actually have to be int32 for this test:
    # even though the cache init should work with any types, the `inputs`
    # created by multiplying the causal attention and the input tokens needs
    # to be an int or the decoding will fail.
    batch = {
        'decoder_input_tokens':
            np.array([[0, 3, 4, 5, 6, 0, 0], [0, 7, 8, 9, 0, 0, 0]],
                     dtype=np.int32),
        'decoder_causal_attention':
            np.array([[1, 1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0]],
                     dtype=np.int32)
    }

    dummy_logits = jnp.ones((2, 1, 5), jnp.float32)

    mock_module = mock.Mock()
    mock_module.apply.return_value = (dummy_logits, {'cache': {}})
    mock_module.dtype = jnp.float32

    def mock_init(self):
      self.module = mock_module
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = functools.partial(decoding.temperature_sample, topk=4)
      self._inputs_bidirectional_attention = False

    with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init):
      model = models.DecoderOnlyModel()

    model.predict_batch({}, batch)

    fake_target = jnp.ones_like(batch['decoder_input_tokens'])

    cache_init_call = mock_module.apply.call_args_list[0]
    self.assertEqual(cache_init_call[0][0], {'params': {}})
    np.testing.assert_allclose(cache_init_call[0][1], fake_target)
    np.testing.assert_allclose(cache_init_call[0][2], fake_target)
    self.assertEqual(cache_init_call[1], {
        'decode': True,
        'enable_dropout': False,
        'mutable': ['cache']
    })

  @parameterized.named_parameters(
      dict(
          testcase_name='no_types',
          shapes={'decoder_input_tokens': [1, 62]},
          types=None),
      dict(
          testcase_name='int32',
          shapes={'decoder_input_tokens': [1, 62]},
          types={'decoder_input_tokens': jnp.int32}),
      dict(
          testcase_name='float32',
          shapes={'decoder_input_tokens': [1, 62]},
          types={'decoder_input_tokens': jnp.int32}),
  )
  def test_get_initial_variables_shapes_and_types(self, shapes, types):
    mock_lm = mock.Mock()
    mock_lm.init.return_value = {'params': {}}
    mock_optimizer_def = mock.Mock()
    rng = mock.Mock()

    def mock_init(self):
      self.module = mock_lm
      self.optimizer_def = mock_optimizer_def

    with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init):
      model = models.DecoderOnlyModel()

    model.get_initial_variables(rng, shapes, types)

    if types is None:
      decoder_input = jnp.ones(
          shapes['decoder_input_tokens'], dtype=jnp.float32)
    else:
      decoder_input = jnp.ones(
          shapes['decoder_input_tokens'], dtype=types['decoder_input_tokens'])

    # Using `.assert_called_once_with` doesn't work because the simple
    # comparison it does for the array arguments fails (truth value of an
    # array is ambiguous).
    called_with = mock_lm.init.call_args
    self.assertEqual(called_with[0][0], rng)
    np.testing.assert_allclose(called_with[0][1], decoder_input)
    np.testing.assert_allclose(called_with[0][2], decoder_input)
    self.assertEqual(mock_lm.init.call_args[1], {'enable_dropout': False})


if __name__ == '__main__':
  absltest.main()