Spaces:
Build error
Build error
File size: 10,491 Bytes
b100e1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 |
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for t5x.optimizers."""
import dataclasses
import functools
import operator
from absl.testing import absltest
from absl.testing import parameterized
import chex
import flax
from flax.core import frozen_dict
import jax
import jax.numpy as jnp
import numpy as np
import optax
import seqio
from t5x import models
from t5x import optimizers
from t5x import partitioning
from t5x import test_utils
from t5x import trainer
from t5x import utils
from t5x.examples.t5 import network
def _assert_numpy_allclose(a, b, atol=None, rtol=None):
a, b = jnp.array(a), jnp.array(b)
a = a.astype(np.float32) if a.dtype == jnp.bfloat16 else a
b = b.astype(np.float32) if b.dtype == jnp.bfloat16 else b
kw = {}
if atol:
kw['atol'] = atol
if rtol:
kw['rtol'] = rtol
np.testing.assert_allclose(a, b, **kw)
def check_eq(xs, ys, atol=None, rtol=None):
  """Asserts that two pytrees match in structure, leaf shapes and values.

  Args:
    xs: first pytree.
    ys: second pytree.
    atol: absolute tolerance forwarded to `_assert_numpy_allclose`.
    rtol: relative tolerance forwarded to `_assert_numpy_allclose`.

  Raises:
    AssertionError: on a tree-structure, leaf-shape or leaf-value mismatch.
  """
  xs_leaves, xs_tree = jax.tree_util.tree_flatten(xs)
  ys_leaves, ys_tree = jax.tree_util.tree_flatten(ys)
  assert xs_tree == ys_tree, f"Tree shapes don't match. \n{xs_tree}\n{ys_tree}"
  # Check every shape before any value so a shape mismatch gets a clear error.
  assert all(
      np.array(x).shape == np.array(y).shape
      for x, y in zip(xs_leaves, ys_leaves)), "Leaves' shapes don't match."
  # The value check raises on its own. The old code asserted on a list of
  # Nones, which checked nothing and failed spuriously for empty trees.
  for x, y in zip(xs_leaves, ys_leaves):
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol)
def flattened_state_dict(x):
  """Serializes `x` with flax and flattens the result to '/'-joined keys."""
  nested = flax.serialization.to_state_dict(x)
  flat = flax.traverse_util.flatten_dict(nested, sep='/')
  return flat
def tree_shape(x):
  """Returns a pytree with the same structure as `x` holding each leaf's shape.

  Uses `jax.tree_util.tree_map`; the top-level `jax.tree_map` alias is
  deprecated and has been removed from recent JAX releases.
  """
  return jax.tree_util.tree_map(jnp.shape, x)
def tree_equals(x, y):
  """Returns True iff pytrees `x` and `y` have equal leaves.

  Leaves are compared with `==`, and each result is reduced with `np.all`.
  This lets array leaves be compared instead of raising on ambiguous truth
  values; scalar leaves behave exactly as a plain `==` would. Mismatched tree
  structures raise, as `tree_map` does.
  """
  # `jax.tree_multimap` was removed from JAX; `tree_util.tree_map` accepts
  # multiple trees directly.
  leaf_matches = jax.tree_util.tree_map(
      lambda a, b: bool(np.all(operator.eq(a, b))), x, y)
  return jax.tree_util.tree_all(leaf_matches)
def get_fake_tokenized_dataset_no_pretokenized(*_, split='validation', **__):
  """Fake tokenized dataset with every '*_pretokenized' feature removed.

  Extra positional/keyword arguments are accepted and ignored so this can be
  swapped in for a dataset-getter with a broader signature.
  """
  def _drop_pretokenized(example):
    # Keep only the features whose key does not end in '_pretokenized'.
    return {
        key: value
        for key, value in example.items()
        if not key.endswith('_pretokenized')
    }

  dataset = test_utils.get_fake_tokenized_dataset(split=split)
  return dataset.map(_drop_pretokenized)
def get_t5_test_model(optimizer_def,
                      **config_overrides) -> models.EncoderDecoderModel:
  """Builds a tiny T5 1.1 encoder-decoder model to use in tests.

  Args:
    optimizer_def: optimizer definition handed to the model.
    **config_overrides: `network.T5Config` fields that replace the tiny
      defaults below (applied via `dataclasses.replace`).

  Returns:
    An `EncoderDecoderModel` whose input and output vocabularies are the same
    fake test vocabulary.
  """
  base_config = network.T5Config(
      vocab_size=128, dtype='bfloat16', emb_dim=8, num_heads=4,
      num_encoder_layers=2, num_decoder_layers=2, head_dim=3, mlp_dim=16,
      mlp_activations=('gelu', 'linear'), dropout_rate=0.0,
      logits_via_embedding=False)
  config = dataclasses.replace(base_config, **config_overrides)
  vocab = test_utils.get_fake_vocab()
  transformer = network.Transformer(config)
  return models.EncoderDecoderModel(
      module=transformer,
      input_vocabulary=vocab,
      output_vocabulary=vocab,
      optimizer_def=optimizer_def)
class BasicTest(chex.TestCase):
  """Structural tests for `optimizers.OptaxWrapper` that do not train."""

  @classmethod
  def get_params(cls):
    """Returns a small, zero-initialized float32 parameter tree fixture."""
    return frozen_dict.FrozenDict({
        'forward': {
            'input_layer': {
                'embedding': jnp.zeros([16, 8], dtype=jnp.float32),
            },
            'output_layer': {
                'layer_norm': {
                    'scale': jnp.zeros([8], dtype=jnp.float32),
                },
                'proj': {
                    'bias': jnp.zeros([1], dtype=jnp.float32),
                    'kernel': jnp.zeros([8, 1], dtype=jnp.float32),
                },
            },
        },
        'loss': {
            'loss_fn': {
                'loss_biases': jnp.zeros([2], dtype=jnp.float32),
            },
        },
    })

  @classmethod
  def get_params_shapes(cls):
    """Returns the tree of leaf shapes of `get_params()`."""
    return jax.tree_util.tree_map(jnp.shape, cls.get_params())

  @classmethod
  def get_param_logical_axes(cls):
    """Returns PartitionSpecs whose tree structure mirrors `get_params()`."""
    return frozen_dict.FrozenDict({
        'forward': {
            'input_layer': {
                'embedding': partitioning.PartitionSpec('vocab', 'embed'),
            },
            'output_layer': {
                'layer_norm': {
                    'scale': partitioning.PartitionSpec('embed',),
                },
                'proj': {
                    'bias':
                        partitioning.PartitionSpec('output_head',),
                    'kernel':
                        partitioning.PartitionSpec('embed', 'output_head'),
                },
            },
        },
        'loss': {
            'loss_fn': {
                'loss_biases': partitioning.PartitionSpec('unmodeled',),
            },
        },
    })

  def test_logical_axes_adamw(self):
    """adamw's derived logical axes reuse the param axes for mu and nu."""
    opt = optax.adamw(0.001, weight_decay=0.001)
    wrapper = optimizers.OptaxWrapper(opt)
    optimizer = wrapper.create(self.get_params())
    got = wrapper.derive_logical_axes(optimizer, self.get_param_logical_axes())
    # Expected: scalar fields (step, count) get no axes (None); the Adam
    # moments mirror the parameter axes; the two trailing EmptyStates carry
    # nothing.
    want = optimizers.Optimizer(
        optimizer_def=wrapper,
        state=optimizers.OptimizerState(
            step=None,
            param_states=(
                optax.ScaleByAdamState(
                    count=None,
                    mu=self.get_param_logical_axes(),
                    nu=self.get_param_logical_axes()),
                optax.EmptyState(),
                optax.EmptyState(),
            )),
        target=self.get_param_logical_axes())
    chex.assert_trees_all_equal(got, want)

  @parameterized.parameters(
      ('sgd', lambda: optax.sgd(1e-2, 0.0)),
      ('adam', lambda: optax.adam(1e-1)),
      ('adamw', lambda: optax.adamw(1e-1)),
      # Previously built optax.adamw, so LAMB itself was never exercised.
      ('lamb', lambda: optax.lamb(1e-1)),
      ('rmsprop', lambda: optax.rmsprop(1e-1)),
      ('rmsprop_momentum', lambda: optax.rmsprop(5e-2, momentum=0.9)),
      ('fromage', lambda: optax.fromage(1e-2)),
      ('adabelief', lambda: optax.adabelief(1e-1)),
      ('radam', lambda: optax.radam(1e-1)),
      ('yogi', lambda: optax.yogi(1.0)),
  )
  def test_sanity_check_logical_axes(self, opt_name, opt_fn):
    """derive_logical_axes runs without error for common optax optimizers.

    Args:
      opt_name: case label consumed only by `parameterized`.
      opt_fn: zero-arg factory returning the optax GradientTransformation.
    """
    del opt_name  # Only used to label the parameterized case.
    opt = opt_fn()
    wrapper = optimizers.OptaxWrapper(opt)
    optimizer = wrapper.create(self.get_params())
    _ = wrapper.derive_logical_axes(optimizer, self.get_param_logical_axes())
    # TODO(rosun): basic sanity check, we just want to make sure if a param
    # name, e.g., `loss_biases` appear in the tree, the corresponding value is
    # always a PartitionSpec.

  def test_adamw_state_serialization(self):
    """adamw optimizer state round-trips through state_dict/restore_state."""
    opt = optax.adamw(0.001, weight_decay=0.001)
    wrapper = optimizers.OptaxWrapper(opt)
    optimizer = wrapper.create(self.get_params())

    state_dict = optimizer.state_dict()
    chex.assert_trees_all_equal(
        frozen_dict.FrozenDict(jax.tree_util.tree_map(jnp.shape, state_dict)),
        frozen_dict.FrozenDict({
            'target': self.get_params_shapes(),
            'state': {
                'step': (),
                'param_states': {
                    '0': {
                        'count': (),
                        'mu': self.get_params_shapes(),
                        'nu': self.get_params_shapes(),
                    },
                    # NB: We eliminate empty tuple leaves from EmptyState() in
                    # OptaxWrapper to avoid having the rest of T5X have to
                    # correctly handle this detail. e.g. we omit these:
                    # '1': {},
                    # '2': {},
                },
            }
        }))

    new_optimizer = optimizer.restore_state(state_dict)

    chex.assert_trees_all_equal(optimizer, new_optimizer)
class OptaxWrapperTest(chex.TestCase):
  """End-to-end tests that train a tiny T5 model through `OptaxWrapper`."""

  def run_train_loop(self, optimizer_def):
    """Trains a tiny T5 model for two single-step rounds and checks its state.

    After training, the trainer's params must still be finite, and the train
    state must survive a state_dict/restore_state round trip with an
    unchanged structure.

    Args:
      optimizer_def: optimizer definition (an `OptaxWrapper` here) for the
        model.
    """
    # Construct input data.
    ds = get_fake_tokenized_dataset_no_pretokenized(split='validation')
    ds = seqio.EncDecFeatureConverter()(
        ds, task_feature_lengths={
            'inputs': 8,
            'targets': 8
        })
    ds = ds.repeat().batch(8)
    ds_iter = ds.as_numpy_iterator()
    first_batch = next(ds_iter)

    model = get_t5_test_model(optimizer_def, vocab_size=128)

    learning_rate_fn = utils.create_learning_rate_scheduler()

    # Shapes/dtypes of one batch drive the model's variable initialization.
    input_shapes = jax.tree_util.tree_map(jnp.shape, first_batch)
    input_types = jax.tree_util.tree_map(lambda x: jnp.dtype(x.dtype),
                                         first_batch)

    partitioner = partitioning.PjitPartitioner(
        num_partitions=2,
        logical_axis_rules=partitioning.standard_logical_axis_rules())

    train_state_initializer = utils.TrainStateInitializer(
        optimizer_def=model.optimizer_def,
        init_fn=model.get_initial_variables,
        input_shapes=input_shapes,
        input_types=input_types,
        partitioner=partitioner)

    train_state_axes = train_state_initializer.train_state_axes
    train_state = train_state_initializer.from_scratch(jax.random.PRNGKey(0))

    trainer_instance = trainer.Trainer(
        model,
        train_state=train_state,
        partitioner=partitioner,
        eval_names=[],
        summary_dir=None,
        train_state_axes=train_state_axes,
        rng=jax.random.PRNGKey(0),
        learning_rate_fn=learning_rate_fn,
        num_microbatches=1)

    chex.assert_tree_all_finite(train_state.params)
    for _ in range(2):
      trainer_instance.train(ds_iter, 1)
    # Check the *trained* state held by the trainer; the local `train_state`
    # still refers to the initial state, so checking it would not test
    # training at all.
    chex.assert_tree_all_finite(trainer_instance.train_state.params)

    # check save/restore structural equality
    restored_instance = trainer_instance.train_state.restore_state(
        trainer_instance.train_state.state_dict())
    chex.assert_tree_all_equal_structs(trainer_instance.train_state,
                                       restored_instance)

  # NOTE(levskaya): these are surprisingly slow tests on CPU.
  # The commented-out cases below are presumably disabled to bound CPU test
  # time; re-enable as needed.
  @parameterized.parameters(
      ('sgd', lambda: optax.sgd(1e-2, 0.0)),
      ('adam', lambda: optax.adam(1e-1)),
      ('adamw', lambda: optax.adamw(1e-1)),
      # Previously built optax.adamw, so LAMB itself was never exercised.
      ('lamb', lambda: optax.lamb(1e-1)),
      # ('rmsprop', lambda: optax.rmsprop(1e-1)),
      # ('rmsprop_momentum', lambda: optax.rmsprop(5e-2, momentum=0.9)),
      # ('fromage', lambda: optax.fromage(1e-2)),
      ('adabelief', lambda: optax.adabelief(1e-1)),
      # ('radam', lambda: optax.radam(1e-1)),
      ('yogi', lambda: optax.yogi(1.0)),
  )
  def test_optimizer(self, opt_name, opt_fn):
    """Runs the short training loop with the given optax optimizer.

    Args:
      opt_name: case label consumed only by `parameterized`.
      opt_fn: zero-arg factory returning the optax GradientTransformation.
    """
    del opt_name  # Only used to label the parameterized case.
    opt = opt_fn()
    optimizer_def = optimizers.OptaxWrapper(opt)
    self.run_train_loop(optimizer_def)
if __name__ == '__main__':
  # Run this module's test cases through absl's test runner.
  absltest.main()
|