# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for t5x.optimizers."""

import dataclasses
import functools
import operator

from absl.testing import absltest
from absl.testing import parameterized
import chex
import flax
from flax.core import frozen_dict
import jax
import jax.numpy as jnp
import numpy as np
import optax
import seqio
from t5x import models
from t5x import optimizers
from t5x import partitioning
from t5x import test_utils
from t5x import trainer
from t5x import utils
from t5x.examples.t5 import network


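# Descriptive note (added): bfloat16 arrays are upcast to float32 before the
# comparison so that np.testing.assert_allclose handles them reliably.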
def _assert_numpy_allclose(a, b, atol=None, rtol=None):
  a, b = jnp.array(a), jnp.array(b)
  a = a.astype(np.float32) if a.dtype == jnp.bfloat16 else a
  b = b.astype(np.float32) if b.dtype == jnp.bfloat16 else b
  kw = {}
  if atol:
    kw['atol'] = atol
  if rtol:
    kw['rtol'] = rtol
  np.testing.assert_allclose(a, b, **kw)


def check_eq(xs, ys, atol=None, rtol=None):
  xs_leaves, xs_tree = jax.tree_flatten(xs)
  ys_leaves, ys_tree = jax.tree_flatten(ys)
  assert xs_tree == ys_tree, f"Tree shapes don't match. \n{xs_tree}\n{ys_tree}"
  assert jax.tree_util.tree_all(
      jax.tree_multimap(lambda x, y: np.array(x).shape == np.array(y).shape,
                        xs_leaves, ys_leaves)), "Leaves' shapes don't match."
  assert jax.tree_multimap(
      functools.partial(_assert_numpy_allclose, atol=atol, rtol=rtol),
      xs_leaves, ys_leaves)


def flattened_state_dict(x):
  s = flax.serialization.to_state_dict(x)
  return flax.traverse_util.flatten_dict(s, sep='/')


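# Small pytree helpers. Illustrative example (added):
#   tree_equals(tree_shape({'w': jnp.zeros([2, 3])}), {'w': (2, 3)})  # -> True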
def tree_shape(x):
  return jax.tree_map(jnp.shape, x)


def tree_equals(x, y):
  return jax.tree_util.tree_all(jax.tree_multimap(operator.eq, x, y))


def get_fake_tokenized_dataset_no_pretokenized(*_, split='validation', **__):
  return test_utils.get_fake_tokenized_dataset(split=split).map(
      lambda x: {k: v for k, v in x.items() if not k.endswith('_pretokenized')})


def get_t5_test_model(optimizer_def,
                      **config_overrides) -> models.EncoderDecoderModel:
  """Returns a tiny T5 1.1 model to use for testing."""
  tiny_config = network.T5Config(
      vocab_size=128,
      dtype='bfloat16',
      emb_dim=8,
      num_heads=4,
      num_encoder_layers=2,
      num_decoder_layers=2,
      head_dim=3,
      mlp_dim=16,
      mlp_activations=('gelu', 'linear'),
      dropout_rate=0.0,
      logits_via_embedding=False,
  )
  tiny_config = dataclasses.replace(tiny_config, **config_overrides)
  vocabulary = test_utils.get_fake_vocab()
  return models.EncoderDecoderModel(
      module=network.Transformer(tiny_config),
      input_vocabulary=vocabulary,
      output_vocabulary=vocabulary,
      optimizer_def=optimizer_def)
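
# Example usage (illustrative, added): the tests below build models such as
#   get_t5_test_model(optimizers.OptaxWrapper(optax.adamw(1e-3)), vocab_size=128)
# so that any optax optimizer can be exercised against the tiny T5 network.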


class BasicTest(chex.TestCase):

  def get_params(cls):
    return frozen_dict.FrozenDict({
        'forward': {
            'input_layer': {
                'embedding': jnp.zeros([16, 8], dtype=jnp.float32),
            },
            'output_layer': {
                'layer_norm': {
                    'scale': jnp.zeros([8], dtype=jnp.float32),
                },
                'proj': {
                    'bias': jnp.zeros([1], dtype=jnp.float32),
                    'kernel': jnp.zeros([8, 1], dtype=jnp.float32),
                },
            },
        },
        'loss': {
            'loss_fn': {
                'loss_biases': jnp.zeros([2], dtype=jnp.float32),
            },
        },
    })

  def get_params_shapes(cls):
    return jax.tree_map(jnp.shape, cls.get_params())

  def get_param_logical_axes(cls):
    return frozen_dict.FrozenDict({
        'forward': {
            'input_layer': {
                'embedding': partitioning.PartitionSpec('vocab', 'embed'),
            },
            'output_layer': {
                'layer_norm': {
                    'scale': partitioning.PartitionSpec('embed',),
                },
                'proj': {
                    'bias':
                        partitioning.PartitionSpec('output_head',),
                    'kernel':
                        partitioning.PartitionSpec('embed', 'output_head'),
                },
            },
        },
        'loss': {
            'loss_fn': {
                'loss_biases': partitioning.PartitionSpec('unmodeled',),
            },
        },
    })

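  # Descriptive note (added): derive_logical_axes should return an Optimizer
  # whose state mirrors the parameter tree, with each leaf replaced by the
  # PartitionSpec of the corresponding parameter and step/count left as None.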
  def test_logical_axes_adamw(self):
    opt = optax.adamw(0.001, weight_decay=0.001)
    wrapper = optimizers.OptaxWrapper(opt)
    optimizer = wrapper.create(self.get_params())
    got = wrapper.derive_logical_axes(optimizer, self.get_param_logical_axes())
    want = optimizers.Optimizer(
        optimizer_def=wrapper,
        state=optimizers.OptimizerState(
            step=None,
            param_states=(
                optax.ScaleByAdamState(
                    count=None,
                    mu=self.get_param_logical_axes(),
                    nu=self.get_param_logical_axes()),
                optax.EmptyState(),
                optax.EmptyState(),
            )),
        target=self.get_param_logical_axes())
    chex.assert_trees_all_equal(got, want)

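  # Assumed minimal parameterization (added, not from the source): any optax
  # optimizer factory could be listed here.
  @parameterized.named_parameters(
      ('adamw', 'adamw', lambda: optax.adamw(0.001, weight_decay=0.001)),
      ('sgd', 'sgd', lambda: optax.sgd(0.001)),
  )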
  def test_sanity_check_logical_axes(self, opt_name, opt_fn):
    opt = opt_fn()
    wrapper = optimizers.OptaxWrapper(opt)
    optimizer = wrapper.create(self.get_params())
    _ = wrapper.derive_logical_axes(optimizer, self.get_param_logical_axes())
    # TODO(rosun): Basic sanity check; we just want to make sure that if a
    # param name, e.g. `loss_biases`, appears in the tree, the corresponding
    # value is always a PartitionSpec.

  def test_adamw_state_serialization(self):
    opt = optax.adamw(0.001, weight_decay=0.001)
    wrapper = optimizers.OptaxWrapper(opt)
    optimizer = wrapper.create(self.get_params())
    state_dict = optimizer.state_dict()
    chex.assert_trees_all_equal(
        frozen_dict.FrozenDict(jax.tree_map(jnp.shape, state_dict)),
        frozen_dict.FrozenDict({
            'target': self.get_params_shapes(),
            'state': {
                'step': (),
                'param_states': {
                    '0': {
                        'count': (),
                        'mu': self.get_params_shapes(),
                        'nu': self.get_params_shapes(),
                    },
                    # NB: We eliminate empty tuple leaves from EmptyState() in
                    # OptaxWrapper to avoid having the rest of T5X have to
                    # correctly handle this detail, e.g. we omit these:
                    # '1': {},
                    # '2': {},
                },
            }
        }))
    new_optimizer = optimizer.restore_state(state_dict)
    chex.assert_trees_all_equal(optimizer, new_optimizer)


class OptaxWrapperTest(chex.TestCase):

  def run_train_loop(self, optimizer_def):
    # Construct input data.
    ds = get_fake_tokenized_dataset_no_pretokenized(split='validation')
    ds = seqio.EncDecFeatureConverter()(
        ds, task_feature_lengths={
            'inputs': 8,
            'targets': 8
        })
    ds = ds.repeat().batch(8)
    ds_iter = ds.as_numpy_iterator()
    first_batch = next(ds_iter)

    model = get_t5_test_model(optimizer_def, vocab_size=128)
    learning_rate_fn = utils.create_learning_rate_scheduler()

    input_shapes = jax.tree_map(jnp.shape, first_batch)
    input_types = jax.tree_map(lambda x: jnp.dtype(x.dtype), first_batch)

    partitioner = partitioning.PjitPartitioner(
        num_partitions=2,
        logical_axis_rules=partitioning.standard_logical_axis_rules())

    train_state_initializer = utils.TrainStateInitializer(
        optimizer_def=model.optimizer_def,
        init_fn=model.get_initial_variables,
        input_shapes=input_shapes,
        input_types=input_types,
        partitioner=partitioner)
    train_state_axes = train_state_initializer.train_state_axes
    train_state = train_state_initializer.from_scratch(jax.random.PRNGKey(0))

    trainer_instance = trainer.Trainer(
        model,
        train_state=train_state,
        partitioner=partitioner,
        eval_names=[],
        summary_dir=None,
        train_state_axes=train_state_axes,
        rng=jax.random.PRNGKey(0),
        learning_rate_fn=learning_rate_fn,
        num_microbatches=1)

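    # Run a couple of single-step train calls; parameters should stay finite
    # before and after training.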
    chex.assert_tree_all_finite(train_state.params)
    for _ in range(2):
      trainer_instance.train(ds_iter, 1)
    chex.assert_tree_all_finite(train_state.params)

    # Check save/restore structural equality.
    restored_instance = trainer_instance.train_state.restore_state(
        trainer_instance.train_state.state_dict())
    chex.assert_tree_all_equal_structs(trainer_instance.train_state,
                                       restored_instance)

  # NOTE(levskaya): these are surprisingly slow tests on CPU.
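  # Assumed minimal parameterization (added, not from the source): any optax
  # optimizer factory could be listed here.
  @parameterized.named_parameters(
      ('adamw', 'adamw', lambda: optax.adamw(0.001, weight_decay=0.001)),
      ('sgd', 'sgd', lambda: optax.sgd(0.001)),
  )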
  def test_optimizer(self, opt_name, opt_fn):
    opt = opt_fn()
    optimizer_def = optimizers.OptaxWrapper(opt)
    self.run_train_loop(optimizer_def)


if __name__ == '__main__':
  absltest.main()