# Mirrored from youtube-music-transcribe / t5x / models_test.py
# (commit b100e1c by juancopi81, "Add t5x and mt3 models"; 38.9 kB).
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for t5x.models."""
import functools
from unittest import mock
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
import flax
from flax import traverse_util
import jax
import jax.numpy as jnp
import numpy as np
import t5.data.tasks # pylint:disable=unused-import
from t5x import decoding
from t5x import models
from t5x import partitioning
from t5x import test_utils
from t5x import trainer as trainer_lib
from t5x import utils
import tensorflow as tf
# Parse absl flags test_srcdir and test_tmpdir.
jax.config.parse_flags_with_absl()
# Short module-level alias for `partitioning.PartitionSpec`.
# NOTE(review): not referenced by the tests visible here; confirm before
# removing.
PartitionSpec = partitioning.PartitionSpec
class ModelsTest(parameterized.TestCase):
  """Tests for the free helper functions in `t5x.models`."""

  def test_remove_prefix(self):
    """Each row drops its first `prefix_lengths[i]` tokens and is zero-filled."""
    jitted = jax.jit(models.remove_prefix)
    seqs = np.array([[1, 2, 3, 4, 5, 6, 7, 0], [6, 7, 8, 9, 10, 11, 0, 0]])
    prefix_lens = np.array([2, 4])
    np.testing.assert_array_equal(
        jitted(seqs, prefix_lens),
        [[3, 4, 5, 6, 7, 0, 0, 0], [10, 11, 0, 0, 0, 0, 0, 0]])

  def test_remove_prefix_zero_len_prefix(self):
    """With zero-length prefixes the sequences come back unchanged."""
    jitted = jax.jit(models.remove_prefix)
    seqs = np.array([[1, 2, 3, 4, 5, 6, 7, 0], [6, 7, 8, 9, 10, 11, 0, 0]])
    np.testing.assert_array_equal(jitted(seqs, np.array([0, 0])), seqs)
# Shape constants shared by `EncoderDecoderModelTest`. The parameterized cases
# are built by decorators at class-definition time, so they cannot use locals.
BATCH_SIZE = 2
ENCODER_LEN = 3
MAX_DECODE_LEN = 4
EMBED_DIM = 5
class EncoderDecoderModelTest(parameterized.TestCase):
  """Tests for `models.EncoderDecoderModel`, mostly with mocked modules."""

  @parameterized.named_parameters(
      dict(
          testcase_name='no_types',
          shapes={
              'encoder_input_tokens': [1, 512],
              'decoder_input_tokens': [1, 62]
          },
          types=None),
      dict(
          testcase_name='int32',
          shapes={
              'encoder_input_tokens': [1, 512],
              'decoder_input_tokens': [1, 62]
          },
          types={
              'encoder_input_tokens': jnp.int32,
              'decoder_input_tokens': jnp.int32
          }),
      # The next two cases are also int32 throughout. They are named after
      # the optional inputs they add, not after a dtype.
      dict(
          testcase_name='int32_positions',
          shapes={
              'encoder_input_tokens': [1, 512],
              'decoder_input_tokens': [1, 62],
              'encoder_positions': [1, 512],
              'decoder_positions': [1, 62],
          },
          types={
              'encoder_input_tokens': jnp.int32,
              'decoder_input_tokens': jnp.int32,
              'encoder_positions': jnp.int32,
              'decoder_positions': jnp.int32
          }),
      dict(
          testcase_name='int32_segment_ids',
          shapes={
              'encoder_input_tokens': [1, 512],
              'decoder_input_tokens': [1, 62],
              'encoder_segment_ids': [1, 512],
              'decoder_segment_ids': [1, 62],
          },
          types={
              'encoder_input_tokens': jnp.int32,
              'decoder_input_tokens': jnp.int32,
              'encoder_segment_ids': jnp.int32,
              'decoder_segment_ids': jnp.int32
          }),
  )
  def test_get_initial_variables_shapes_and_types(self, shapes, types):
    """`get_initial_variables` calls `module.init` with ones of each shape.

    With `types=None` the token inputs are expected as float32. Optional
    inputs that are absent from `shapes` must be passed to `init` as `None`.
    """
    mock_transformer = mock.Mock()
    mock_transformer.init.return_value = {'params': {}}
    mock_optimizer_def = mock.Mock()
    rng = mock.Mock()

    def mock_init(self):
      self.module = mock_transformer
      self.optimizer_def = mock_optimizer_def

    with mock.patch.object(
        models.EncoderDecoderModel, '__init__', new=mock_init):
      model = models.EncoderDecoderModel()
      model.get_initial_variables(rng, shapes, types)

    if types is None:
      encoder_input = jnp.ones(
          shapes['encoder_input_tokens'], dtype=jnp.float32)
      decoder_input = jnp.ones(
          shapes['decoder_input_tokens'], dtype=jnp.float32)
    else:
      encoder_input = jnp.ones(
          shapes['encoder_input_tokens'], dtype=types['encoder_input_tokens'])
      decoder_input = jnp.ones(
          shapes['decoder_input_tokens'], dtype=types['decoder_input_tokens'])

    # `.assert_called_once_with` doesn't work here. Its simple equality
    # comparison fails on the array arguments (the truth value of an array is
    # ambiguous).
    called_with = mock_transformer.init.call_args
    self.assertEqual(called_with[0][0], rng)
    np.testing.assert_allclose(called_with[0][1], encoder_input)
    # The same dummy decoder tokens fill both decoder positional slots
    # (inputs and targets).
    np.testing.assert_allclose(called_with[0][2], decoder_input)
    np.testing.assert_allclose(called_with[0][3], decoder_input)

    # Each optional input is either ones of the requested shape/dtype, or
    # None when it was not requested.
    for name in ('encoder_positions', 'decoder_positions',
                 'encoder_segment_ids', 'decoder_segment_ids'):
      if name in shapes:
        expected = jnp.ones(shapes[name], dtype=types[name])
        np.testing.assert_allclose(called_with[1][name], expected,
                                   err_msg=name)
      else:
        self.assertIsNone(called_with[1][name], msg=name)

    self.assertFalse(called_with[1]['decode'])
    self.assertFalse(called_with[1]['enable_dropout'])

  @parameterized.named_parameters(
      dict(testcase_name='no_force_decoding', prompt_with_targets=False),
      dict(testcase_name='force_decoding', prompt_with_targets=True),
  )
  def test_prompt_with_targets(self, prompt_with_targets):
    """Decoder inputs reach the decode fn only when prompting with targets."""
    batch_size, encoder_len, max_decode_len, emb_dim = 2, 3, 4, 5
    batch = {
        'encoder_input_tokens':
            np.zeros((batch_size, encoder_len), dtype=np.int32),
        'decoder_input_tokens':
            np.full([batch_size, max_decode_len], 2, dtype=np.int32)
    }

    # These dummy logits represent a degenerate probability distribution,
    # where all of the probability mass is on one item. For batch element 0
    # that is vocabulary index 3; for batch element 1 it is index 4.
    # We test `_predict_step` to avoid having to define a task and its
    # vocabulary.
    dummy_logits = jnp.expand_dims(
        jnp.array([[-1e7, -1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, -1e7, 0]]),
        axis=1)

    mock_decode_fn = mock.Mock()
    mock_decode_fn.return_value = (np.full([batch_size, max_decode_len, 1],
                                           3,
                                           dtype=np.int32),
                                   np.full([batch_size, 1],
                                           1.0,
                                           dtype=np.float32))

    class MockModule:

      def __init__(self):
        self.dtype = jnp.float32

      def apply(self, *args, method=None, **kwargs):
        del args, kwargs
        if method is None:  # use for module.`__call__`
          return (dummy_logits, {'cache': {}})
        else:
          return method()

      def encode(self):
        return jnp.zeros((batch_size, encoder_len, emb_dim))

      def decode(self):
        return (dummy_logits, {'cache': {}})

    def mock_init(self):
      self.module = MockModule()
      self.module.scan_layers = False
      self._input_vocabulary = mock.Mock(eos_id=1)
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = mock_decode_fn

    with mock.patch.object(
        models.EncoderDecoderModel, '__init__', new=mock_init):
      model = models.EncoderDecoderModel()

    model.predict_batch_with_aux({},
                                 batch,
                                 prompt_with_targets=prompt_with_targets)

    if prompt_with_targets:
      expected_inputs = batch['decoder_input_tokens']
    else:
      expected_inputs = np.zeros([batch_size, max_decode_len], dtype=np.int32)

    self.assertEqual(mock_decode_fn.call_count, 1)
    # Read `inputs` from the recorded kwargs; `assert_called_with` doesn't
    # work well with np.array comparison.
    np.testing.assert_array_equal(mock_decode_fn.mock_calls[0][2]['inputs'],
                                  expected_inputs)

  def test_predict_batch_loop_and_caches_are_equal(self):
    """While the prompt is forced, all beams of an example share one cache."""
    vocab_size = 50
    lengths = np.array([[2], [3]])
    batch_size, beam_size, encoder_len, max_decode_len = 2, 2, 3, 7
    batch = {
        'encoder_input_tokens':
            np.zeros((batch_size, encoder_len), dtype=np.int32),
        'decoder_target_tokens':
            np.zeros((batch_size, encoder_len), dtype=np.int32),
        'decoder_input_tokens':
            np.concatenate(
                [
                    np.expand_dims(
                        np.concatenate(
                            [[0],
                             np.arange(9, 9 + lengths[0][0], dtype=np.int32),
                             np.zeros((max_decode_len - lengths[0][0] - 1),
                                      dtype=np.int32)]),
                        axis=0),  # First element
                    np.expand_dims(
                        np.concatenate(
                            [[0],
                             np.arange(3, 3 + lengths[1][0], dtype=np.int32),
                             np.zeros((max_decode_len - lengths[1][0] - 1),
                                      dtype=np.int32)]),
                        axis=0)  # Second element
                ],
                axis=0),
    }

    model = test_utils.get_t5_test_model(vocab_size=vocab_size)
    module = model.module
    params = module.init(
        jax.random.PRNGKey(0),
        jnp.ones((batch_size, encoder_len)),
        jnp.ones((batch_size, max_decode_len)),
        jnp.ones((batch_size, max_decode_len)),
        enable_dropout=False)['params']

    def mock_init(self):
      self.module = module
      # Set the EOS token to be larger than the vocabulary size. This forces
      # the model to decode all the way to `max_decode_length`, allowing us to
      # test behavior when one element reaches the end before the others.
      self._output_vocabulary = mock.Mock(eos_id=vocab_size + 12)
      self._decode_fn = decoding.beam_search

    with mock.patch.object(
        models.EncoderDecoderModel, '__init__', new=mock_init):
      model = models.EncoderDecoderModel()

    with mock.patch.object(
        model, '_compute_logits_from_slice',
        autospec=True) as tokens_to_logits_mock:
      # Give the mock a side effect that calls the method on the class, with
      # the instance partialed in as `self`. This runs the actual code and
      # records the inputs, without the infinite loop you would get from
      # calling `instance.method`.
      tokens_to_logits_mock.side_effect = functools.partial(
          models.EncoderDecoderModel._compute_logits_from_slice, model)
      # Disable jit so that the `lax.while_loop` isn't traced. Otherwise the
      # tracers collected in the mock's call_args would generally trigger a
      # tracer leak error.
      with jax.disable_jit():
        _ = model.predict_batch_with_aux(
            params, batch, prompt_with_targets=True, num_decodes=beam_size)

    # Per-step inputs and per-step, per-layer cache keys/values seen by the
    # tokens -> logits function.
    all_inputs = []
    all_cache_keys = []
    all_cache_values = []
    # Force decoding currently generates logits at every step, so there should
    # be `max_decode_length` calls to our tokens -> logits function.
    self.assertLen(tokens_to_logits_mock.call_args_list, max_decode_len)
    for tokens_call in tokens_to_logits_mock.call_args_list:
      # Inputs: [B * Be, 1]
      inputs, cache = tokens_call[0]
      flat_cache = traverse_util.flatten_dict(flax.core.unfreeze(cache))
      all_inputs.append(inputs)
      all_cache_keys.append(
          [v for k, v in flat_cache.items() if k[-1] == 'cached_key'])
      all_cache_values.append(
          [v for k, v in flat_cache.items() if k[-1] == 'cached_value'])

    # Concatenate the per-step inputs into a single [B * Be, DL] block.
    all_inputs = np.concatenate(all_inputs, axis=1)
    # Stack the per-step caches into one block per layer, with the decode step
    # on axis 1: [B * Be, DL, ...] for each layer.
    all_cache_keys = [np.stack(c, axis=1) for c in zip(*all_cache_keys)]
    all_cache_values = [np.stack(c, axis=1) for c in zip(*all_cache_values)]

    # For each example, while the prompt is being forced, every beam must see
    # the prompt token and hold a cache identical to the first beam's.
    for b in range(batch_size):
      first_beam = b * beam_size
      for i, input_token in enumerate(all_inputs[first_beam]):
        if i < lengths[b, 0]:
          self.assertEqual(input_token, batch['decoder_input_tokens'][b][i])
          for other_beam in range(first_beam + 1, first_beam + beam_size):
            # For all layers.
            for cache_keys in all_cache_keys:
              np.testing.assert_array_equal(cache_keys[first_beam][i],
                                            cache_keys[other_beam][i])
            for cache_values in all_cache_values:
              np.testing.assert_array_equal(cache_values[first_beam][i],
                                            cache_values[other_beam][i])

  def test_score_batch(self):
    """`score_batch` applies the module once and sums weighted log-probs."""
    encoder_input_tokens = jnp.ones((2, 3))
    # For this test, decoder input and target tokens are dummy values.
    decoder_input_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]])
    decoder_target_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]])
    decoder_loss_weights = jnp.array([[1, 1, 1, 0], [0, 1, 0, 1]])
    logits = jnp.arange(0, 24).reshape((2, 4, 3))
    params = {'foo': jnp.zeros(3)}

    mock_transformer = mock.Mock()
    mock_transformer.apply.return_value = logits
    mock_transformer.dtype = jnp.float32

    batch = {
        'encoder_input_tokens': encoder_input_tokens,
        'decoder_input_tokens': decoder_input_tokens,
        'decoder_target_tokens': decoder_target_tokens,
        'decoder_loss_weights': decoder_loss_weights
    }

    def mock_init(self):
      self.module = mock_transformer

    with mock.patch.object(
        models.EncoderDecoderModel, '__init__', new=mock_init):
      model = models.EncoderDecoderModel()
    res = model.score_batch(params, batch)

    mock_transformer.apply.assert_called_with({'params': params},
                                              encoder_input_tokens,
                                              decoder_input_tokens,
                                              decoder_target_tokens,
                                              encoder_segment_ids=None,
                                              decoder_segment_ids=None,
                                              encoder_positions=None,
                                              decoder_positions=None,
                                              decode=False,
                                              enable_dropout=False,
                                              rngs=None,
                                              mutable=False)
    np.testing.assert_allclose(res, [-3.222973, -1.815315], rtol=1e-4)

  def test_score_batch_can_return_intermediates(self):
    """With `return_intermediates`, decoder intermediates are added."""
    encoder_input_tokens = jnp.ones((2, 3))
    # For this test, decoder input and target tokens are dummy values.
    decoder_input_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]])
    decoder_target_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]])
    decoder_loss_weights = jnp.array([[1, 1, 1, 0], [0, 1, 0, 1]])
    logits = jnp.arange(0, 24).reshape((2, 4, 3))
    modified_variables = {'intermediates': {'bar': jnp.ones(5)}}
    params = {'foo': jnp.zeros(3)}

    mock_transformer = mock.Mock()
    mock_transformer.apply.return_value = (logits, modified_variables)
    mock_transformer.dtype = jnp.float32

    batch = {
        'encoder_input_tokens': encoder_input_tokens,
        'decoder_input_tokens': decoder_input_tokens,
        'decoder_target_tokens': decoder_target_tokens,
        'decoder_loss_weights': decoder_loss_weights
    }

    def mock_init(self):
      self.module = mock_transformer

    with mock.patch.object(
        models.EncoderDecoderModel, '__init__', new=mock_init):
      model = models.EncoderDecoderModel()
    scores, intermediates = model.score_batch(
        params, batch, return_intermediates=True)

    mock_transformer.apply.assert_called_with({'params': params},
                                              encoder_input_tokens,
                                              decoder_input_tokens,
                                              decoder_target_tokens,
                                              encoder_segment_ids=None,
                                              decoder_segment_ids=None,
                                              encoder_positions=None,
                                              decoder_positions=None,
                                              decode=False,
                                              enable_dropout=False,
                                              rngs=None,
                                              mutable=['intermediates'])
    np.testing.assert_allclose(scores, [-3.222973, -1.815315], rtol=1e-4)
    # Existing intermediates are passed out unchanged.
    np.testing.assert_allclose(intermediates['bar'], jnp.ones(5))
    # score_batch() inserts a new collection of decoder intermediates.
    np.testing.assert_allclose(intermediates['decoder']['loss_weights'][0],
                               decoder_loss_weights)
    np.testing.assert_allclose(intermediates['decoder']['target_tokens'][0],
                               decoder_target_tokens)

  def test_train_transformer_wmt(self):
    """Runs one real training step of the small T5 test model end to end."""
    # Dummy input data
    input_shape = (16, 8)
    encoder_input_tokens = np.ones(shape=input_shape, dtype=np.float32)
    decoder_input_tokens = 5 * np.ones(shape=input_shape, dtype=np.float32)
    decoder_target_tokens = 5 * np.ones(input_shape, dtype=np.float32)
    input_data = {
        'encoder_input_tokens': encoder_input_tokens,
        'decoder_input_tokens': decoder_input_tokens,
        'decoder_target_tokens': decoder_target_tokens
    }

    partitioner = partitioning.PjitPartitioner(num_partitions=1)
    model = test_utils.get_t5_test_model()

    ds_iter = tf.data.Dataset.from_tensors(input_data).as_numpy_iterator()
    input_shapes = {k: input_shape for k in input_data}

    train_state_initializer = utils.TrainStateInitializer(
        optimizer_def=model.optimizer_def,
        init_fn=model.get_initial_variables,
        input_shapes=input_shapes,
        partitioner=partitioner)
    train_state_axes = train_state_initializer.train_state_axes
    train_state = train_state_initializer.from_scratch(jax.random.PRNGKey(0))

    trainer = trainer_lib.Trainer(
        model,
        train_state=train_state,
        partitioner=partitioner,
        eval_names=[],
        summary_dir=None,
        train_state_axes=train_state_axes,
        rng=jax.random.PRNGKey(0),
        learning_rate_fn=lambda x: 0.001,
        num_microbatches=1)

    trainer.train(ds_iter, 1)
    # Log the trainer's state, not the local `train_state` built before
    # training (presumably `Trainer` keeps its updated state in
    # `trainer.train_state` -- confirm against trainer.py).
    logging.info('optimizer after first step %s', trainer.train_state.params)

  @parameterized.parameters(
      {'decode_fn': decoding.beam_search},
      {'decode_fn': functools.partial(decoding.temperature_sample, topk=4)})
  def test_predict_batch(self, decode_fn):
    """Degenerate logits make every decoded token deterministic."""
    batch_size, encoder_len, max_decode_len, emb_dim = 2, 3, 4, 5
    batch = {
        'encoder_input_tokens':
            np.zeros((batch_size, encoder_len), dtype=np.int32),
        'decoder_input_tokens':
            np.zeros((batch_size, max_decode_len), dtype=np.int32)
    }

    # These dummy logits represent a degenerate probability distribution,
    # where all of the probability mass is on one item. For batch element 0
    # that is vocabulary index 2; for batch element 1 it is index 3.
    # We test `_predict_step` to avoid having to define a task and its
    # vocabulary.
    dummy_logits = jnp.expand_dims(
        jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1)

    class MockModule:

      def __init__(self):
        self.dtype = jnp.float32

      def apply(self, *args, method=None, **kwargs):
        del args, kwargs
        if method is None:  # use for module.`__call__`
          return (dummy_logits, {'cache': {}})
        else:
          return method()

      def encode(self):
        return jnp.zeros((batch_size, encoder_len, emb_dim))

      def decode(self):
        return (dummy_logits, {'cache': {}})

    def mock_init(self):
      self.module = MockModule()
      self.module.scan_layers = False
      self._input_vocabulary = mock.Mock(eos_id=1)
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = decode_fn

    with mock.patch.object(
        models.EncoderDecoderModel, '__init__', new=mock_init):
      model = models.EncoderDecoderModel()

    actual = model.predict_batch({}, batch)
    # The predicted token is always 2 for the first batch element and always
    # 3 for the second.
    expected = [[2] * max_decode_len, [3] * max_decode_len]
    np.testing.assert_array_equal(actual, expected)

  def test_predict_batch_rng(self):
    """`rng` and `decoder_params['decode_rng']` are forwarded; both is an error."""
    batch = {
        'encoder_input_tokens': np.zeros((2, 1), dtype=np.int32),
        'decoder_input_tokens': np.zeros((2, 2), dtype=np.int32)
    }

    decode_fn_mock = mock.Mock(
        return_value=(np.zeros((2, 2, 3)), np.zeros((2, 2))))

    def mock_init(self):
      self.module = mock.Mock(
          apply=mock.Mock(side_effect=lambda *_, **kwargs: (  # pylint:disable=g-long-lambda,g-long-ternary
              np.zeros((2, 2)), {
                  'cache': None
              }) if 'mutable' in kwargs else np.zeros((2, 2))))
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = decode_fn_mock

    with mock.patch.object(
        models.EncoderDecoderModel, '__init__', new=mock_init):
      model = models.EncoderDecoderModel()

    # No RNG
    model.predict_batch({}, batch)
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertNotIn('decode_rng', decode_fn_kwargs)

    # No RNG (w/ aux)
    model.predict_batch_with_aux({}, batch)
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertNotIn('decode_rng', decode_fn_kwargs)

    # decoder_params RNG
    model.predict_batch_with_aux({}, batch, decoder_params={'decode_rng': 3})
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertEqual(decode_fn_kwargs['decode_rng'], 3)

    # rng RNG
    model.predict_batch({}, batch, rng=4)
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertEqual(decode_fn_kwargs['decode_rng'], 4)

    # rng RNG (w/ aux)
    model.predict_batch_with_aux({}, batch, rng=4)
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertEqual(decode_fn_kwargs['decode_rng'], 4)

    # Both
    with self.assertRaisesWithLiteralMatch(
        ValueError, 'Got RNG both from the `rng` argument (4) and '
        "`decoder_params['decode_rng']` (3). Please specify one or the other."):
      model.predict_batch_with_aux({},
                                   batch,
                                   rng=4,
                                   decoder_params={'decode_rng': 3})

  @parameterized.named_parameters(
      dict(
          testcase_name='int32',
          batch={
              'encoder_input_tokens':
                  np.zeros((BATCH_SIZE, ENCODER_LEN), dtype=np.int32),
              'decoder_input_tokens':
                  np.zeros((BATCH_SIZE, MAX_DECODE_LEN), dtype=np.int32)
          }),
      dict(
          testcase_name='float32',
          batch={
              'encoder_input_tokens':
                  np.zeros((BATCH_SIZE, ENCODER_LEN), dtype=np.float32),
              'decoder_input_tokens':
                  np.zeros((BATCH_SIZE, MAX_DECODE_LEN), dtype=np.float32)
          }))
  def test_predict_batch_fake_input_shapes_and_types(self, batch):
    """The cache-init `apply` call gets ones shaped like the batch inputs."""
    # Uniform (all-ones) logits. This test only inspects the cache-init call,
    # so the decoded tokens are irrelevant. We test `_predict_step` to avoid
    # having to define a task and its vocabulary.
    dummy_logits = jnp.ones((2, 1, 4), jnp.float32)

    class MockModule:

      def __init__(self):
        self.dtype = jnp.float32
        self.call_args_list = []

      def apply(self, *args, method=None, **kwargs):
        # Not a real Mock, so record the args/kwargs by hand.
        self.call_args_list.append({'args': args, 'kwargs': kwargs})
        del args, kwargs
        if method is None:  # use for module.`__call__`
          return (dummy_logits, {'cache': {}})
        else:
          return method()

      def encode(self):
        return jnp.zeros((BATCH_SIZE, ENCODER_LEN, EMBED_DIM))

      def decode(self):
        return (dummy_logits, {'cache': {}})

    def mock_init(self):
      self.module = MockModule()
      self.module.scan_layers = False
      self._input_vocabulary = mock.Mock(eos_id=1)
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = decoding.beam_search
      self._inputs_bidirectional_attention = False

    with mock.patch.object(
        models.EncoderDecoderModel, '__init__', new=mock_init):
      model = models.EncoderDecoderModel()
    model.predict_batch({}, batch)

    fake_inputs = jnp.ones_like(batch['encoder_input_tokens'])
    fake_target = jnp.ones_like(batch['decoder_input_tokens'])

    cache_init_call = model.module.call_args_list[0]
    self.assertEqual(cache_init_call['args'][0], {'params': {}})
    np.testing.assert_allclose(cache_init_call['args'][1], fake_inputs)
    np.testing.assert_allclose(cache_init_call['args'][2], fake_target)
    np.testing.assert_allclose(cache_init_call['args'][3], fake_target)
    self.assertEqual(cache_init_call['kwargs'], {
        'decode': True,
        'enable_dropout': False,
        'mutable': ['cache']
    })
class DecoderOnlyModelTest(parameterized.TestCase):
  """Tests for `models.DecoderOnlyModel`, using mocked modules."""

  def test_predict_batch_visible_in_prefill(self):
    """The prefill call attends to, and only to, each example's prompt."""
    batch_size = 2
    seq_len = 10
    lengths = np.array([[6], [3]])
    batch = {
        'decoder_input_tokens':
            np.tile(
                np.expand_dims(np.arange(seq_len, dtype=np.int32), axis=0),
                (batch_size, 1)),
        'decoder_causal_attention':
            (lengths > np.arange(seq_len)).astype(np.int32)
    }

    dummy_logits = jnp.expand_dims(
        jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1)

    mock_module = mock.Mock()
    mock_module.apply.return_value = (dummy_logits, {'cache': {}})
    mock_module.dtype = jnp.float32

    def mock_init(self):
      self.module = mock_module
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = functools.partial(decoding.temperature_sample, topk=4)
      self._inputs_bidirectional_attention = False

    with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init):
      model = models.DecoderOnlyModel()

    model.predict_batch({}, batch)
    prefill_call = mock_module.apply.call_args_list[1]
    kwargs = prefill_call[1]
    inputs = kwargs['decoder_input_tokens']
    # For the prefill call, 'decoder_causal_attention' is passed as
    # 'decoder_target_tokens'.
    targets = kwargs['decoder_target_tokens']

    self.assertTrue(kwargs['prefill'])
    np.testing.assert_array_equal(kwargs['prefill_lengths'],
                                  np.squeeze(lengths - 1, axis=-1))

    # The non-padding values of the "targets" must cover all of the input, so
    # that all of it is considered in the attention mask.
    np.testing.assert_array_equal(inputs * targets, inputs)
    # The first input value is always 0, so the masking check above would not
    # catch a 0 in the first target position. Check it explicitly.
    np.testing.assert_array_equal(targets[:, 0], np.ones_like(targets[:, 0]))

    # The targets must be properly removed. The input counts up from 0, so the
    # largest value (the last input) should equal its position, which is
    # length - 1. A larger max would mean the target was not masked.
    np.testing.assert_array_equal(
        np.max(inputs, axis=1), np.squeeze(lengths - 1, axis=-1))

  def test_predict_batch(self):
    """Prompts are stripped and degenerate logits give fixed continuations."""
    batch = {
        'decoder_input_tokens':
            np.array([[0, 3, 4, 5, 6, 0, 0], [0, 7, 8, 9, 0, 0, 0]]),
        'decoder_causal_attention':
            np.array([[1, 1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0]])
    }

    # These dummy logits represent a degenerate probability distribution,
    # where all of the probability mass is on one item. For batch element 0
    # that is vocabulary index 2; for batch element 1 it is index 3.
    # We test `_predict_step` to avoid having to define a task and its
    # vocabulary.
    dummy_logits = jnp.expand_dims(
        jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1)

    mock_module = mock.Mock()
    mock_module.apply.return_value = (dummy_logits, {'cache': {}})
    mock_module.dtype = jnp.float32

    def mock_init(self):
      self.module = mock_module
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = functools.partial(decoding.temperature_sample, topk=4)
      self._inputs_bidirectional_attention = False

    with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init):
      model = models.DecoderOnlyModel()

    actual = model.predict_batch({}, batch)
    expected = [[2, 2, 2, 2, 2, 0, 0], [3, 3, 3, 3, 3, 3, 0]]

    # The expected progression of the first element of 'decoder_input_tokens':
    # [0, 3, 4, 5, 6, 0, 0] -> [0, 3, 4, 0, 0, 0, 0] ->
    # [3, 4, 2, 2, 2, 2, 2] -> [2, 2, 2, 2, 2, 0, 0]

    # The expected progression of the second element of 'decoder_input_tokens':
    # [0, 7, 8, 9, 0, 0, 0] -> [0, 7, 0, 0, 0, 0, 0] ->
    # [7, 3, 3, 3, 3, 3, 3] -> [3, 3, 3, 3, 3, 3, 0]
    np.testing.assert_array_equal(actual, expected)

  def test_predict_batch_rng(self):
    """`rng` and `decoder_params['decode_rng']` are forwarded; both is an error."""
    batch = {
        'decoder_input_tokens': np.zeros((2, 2), dtype=np.int32),
        'decoder_causal_attention': np.zeros((2, 2), dtype=np.int32)
    }

    decode_fn_mock = mock.Mock(
        return_value=(np.zeros((2, 2, 3)), np.zeros((2, 2))))

    def mock_init(self):
      self.module = mock.Mock(
          apply=mock.Mock(side_effect=lambda *_, **kwargs: (  # pylint:disable=g-long-lambda,g-long-ternary
              np.zeros((2, 2)), {
                  'cache': None
              }) if 'mutable' in kwargs else np.zeros((2, 2))))
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = decode_fn_mock
      self._inputs_bidirectional_attention = False

    with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init):
      model = models.DecoderOnlyModel()

    # No RNG
    model.predict_batch({}, batch)
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertNotIn('decode_rng', decode_fn_kwargs)

    # No RNG (w/ aux)
    model.predict_batch_with_aux({}, batch)
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertNotIn('decode_rng', decode_fn_kwargs)

    # decoder_params RNG
    model.predict_batch_with_aux({}, batch, decoder_params={'decode_rng': 3})
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertEqual(decode_fn_kwargs['decode_rng'], 3)

    # rng RNG
    model.predict_batch({}, batch, rng=4)
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertEqual(decode_fn_kwargs['decode_rng'], 4)

    # rng RNG (w/ aux)
    model.predict_batch_with_aux({}, batch, rng=4)
    _, decode_fn_kwargs = decode_fn_mock.call_args
    self.assertEqual(decode_fn_kwargs['decode_rng'], 4)

    # Both
    with self.assertRaisesWithLiteralMatch(
        ValueError, 'Got RNG both from the `rng` argument (4) and '
        "`decoder_params['decode_rng']` (3). Please specify one or the other."):
      model.predict_batch_with_aux({},
                                   batch,
                                   rng=4,
                                   decoder_params={'decode_rng': 3})

  def test_predict_batch_num_decodes_temperature_sample(self):
    """`num_decodes=2` with `return_all_decodes` yields both samples."""
    batch = {
        'decoder_input_tokens': np.array([
            [0, 3, 4, 5, 6, 0, 0],
        ]),
        'decoder_causal_attention': np.array([
            [1, 1, 1, 0, 0, 0, 0],
        ])
    }

    # These dummy logits represent a degenerate probability distribution,
    # where all of the probability mass is on one item. There are two samples:
    # sample 0 always picks vocabulary index 2 and sample 1 always picks
    # index 3. The samples should technically be identical, since they share
    # a prompt, but distinct logits make testing easier.
    dummy_logits = jnp.expand_dims(
        jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1)

    mock_module = mock.Mock()
    mock_module.apply.return_value = (dummy_logits, {'cache': {}})
    mock_module.dtype = jnp.float32

    def mock_init(self):
      self.module = mock_module
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = functools.partial(decoding.temperature_sample, topk=4)
      self._inputs_bidirectional_attention = False

    with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init):
      model = models.DecoderOnlyModel()

    actual_output, aux = model.predict_batch_with_aux({},
                                                      batch,
                                                      num_decodes=2,
                                                      return_all_decodes=True)

    expected_output = [[[2, 2, 2, 2, 2, 0, 0], [3, 3, 3, 3, 3, 0, 0]]]
    expected_scores = [[0., 0.]]

    # Both samples start from the same prompt:
    # [0, 3, 4, 5, 6, 0, 0] -> [0, 3, 4, 0, 0, 0, 0]
    # Sample 0 (row 0 of dummy_logits) then progresses as
    # [3, 4, 2, 2, 2, 2, 2] -> [2, 2, 2, 2, 2, 0, 0]
    # and sample 1 (row 1 of dummy_logits) as
    # [3, 4, 3, 3, 3, 3, 3] -> [3, 3, 3, 3, 3, 0, 0]
    np.testing.assert_array_equal(actual_output, expected_output)
    np.testing.assert_array_equal(aux['scores'], expected_scores)

  def test_predict_batch_fake_input_shapes_and_types(self):
    """The cache-init `apply` call gets ones shaped like the decoder inputs."""
    # The input and causal attention have to be int32 for this test. Cache
    # init itself should work with any dtype. However, the `inputs` created by
    # multiplying the causal attention with the input tokens must be ints, or
    # decoding will fail.
    batch = {
        'decoder_input_tokens':
            np.array([[0, 3, 4, 5, 6, 0, 0], [0, 7, 8, 9, 0, 0, 0]],
                     dtype=np.int32),
        'decoder_causal_attention':
            np.array([[1, 1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0]],
                     dtype=np.int32)
    }

    dummy_logits = jnp.ones((2, 1, 5), jnp.float32)

    mock_module = mock.Mock()
    mock_module.apply.return_value = (dummy_logits, {'cache': {}})
    mock_module.dtype = jnp.float32

    def mock_init(self):
      self.module = mock_module
      self._output_vocabulary = mock.Mock(eos_id=1)
      self._decode_fn = functools.partial(decoding.temperature_sample, topk=4)
      self._inputs_bidirectional_attention = False

    with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init):
      model = models.DecoderOnlyModel()

    model.predict_batch({}, batch)

    fake_target = jnp.ones_like(batch['decoder_input_tokens'])

    cache_init_call = mock_module.apply.call_args_list[0]
    self.assertEqual(cache_init_call[0][0], {'params': {}})
    np.testing.assert_allclose(cache_init_call[0][1], fake_target)
    np.testing.assert_allclose(cache_init_call[0][2], fake_target)
    self.assertEqual(cache_init_call[1], {
        'decode': True,
        'enable_dropout': False,
        'mutable': ['cache']
    })

  @parameterized.named_parameters(
      dict(
          testcase_name='no_types',
          shapes={'decoder_input_tokens': [1, 62]},
          types=None),
      dict(
          testcase_name='int32',
          shapes={'decoder_input_tokens': [1, 62]},
          types={'decoder_input_tokens': jnp.int32}),
      dict(
          testcase_name='float32',
          shapes={'decoder_input_tokens': [1, 62]},
          types={'decoder_input_tokens': jnp.float32}),
  )
  def test_get_initial_variables_shapes_and_types(self, shapes, types):
    """`get_initial_variables` calls `module.init` with ones of each shape.

    With `types=None` the decoder tokens are expected as float32.
    """
    mock_lm = mock.Mock()
    mock_lm.init.return_value = {'params': {}}
    mock_optimizer_def = mock.Mock()
    rng = mock.Mock()

    def mock_init(self):
      self.module = mock_lm
      self.optimizer_def = mock_optimizer_def

    with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init):
      model = models.DecoderOnlyModel()
      model.get_initial_variables(rng, shapes, types)

    if types is None:
      decoder_input = jnp.ones(
          shapes['decoder_input_tokens'], dtype=jnp.float32)
    else:
      decoder_input = jnp.ones(
          shapes['decoder_input_tokens'], dtype=types['decoder_input_tokens'])

    # `.assert_called_once_with` doesn't work here. Its simple equality
    # comparison fails on the array arguments (the truth value of an array is
    # ambiguous).
    called_with = mock_lm.init.call_args
    self.assertEqual(called_with[0][0], rng)
    np.testing.assert_allclose(called_with[0][1], decoder_input)
    np.testing.assert_allclose(called_with[0][2], decoder_input)
    self.assertEqual(called_with[1], {'enable_dropout': False})
if __name__ == '__main__':
  # Run every test case in this module through absl's test runner.
  absltest.main()