# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for t5x.adafactor.""" | |
import functools | |
import operator | |
from typing import Sequence | |
from absl.testing import absltest | |
from absl.testing import parameterized | |
import flax | |
from flax import optim # used for equivalence testing only | |
from flax import traverse_util | |
import jax | |
from jax import numpy as jnp | |
from jax import random | |
import numpy as np | |
from t5x import adafactor | |
from t5x import optimizers | |
OptimizerState = optimizers.OptimizerState | |
_AdafactorHyperParams = adafactor._AdafactorHyperParams | |
_AdafactorParamState = adafactor._AdafactorParamState | |
_BATCH = adafactor.FactorDim.BATCH | |
_ROW = adafactor.FactorDim.ROW | |
_COL = adafactor.FactorDim.COLUMN | |

# Testing helpers


def _assert_numpy_allclose(a, b, atol=None, rtol=None):
  a, b = jnp.array(a), jnp.array(b)
  a = a.astype(np.float32) if a.dtype == jnp.bfloat16 else a
  b = b.astype(np.float32) if b.dtype == jnp.bfloat16 else b
  kw = {}
  if atol:
    kw['atol'] = atol
  if rtol:
    kw['rtol'] = rtol
  np.testing.assert_allclose(a, b, **kw)


def check_eq(xs, ys, atol=None, rtol=None):
  xs_leaves, xs_tree = jax.tree_util.tree_flatten(xs)
  ys_leaves, ys_tree = jax.tree_util.tree_flatten(ys)
  assert xs_tree == ys_tree, f"Tree shapes don't match.\n{xs_tree}\n{ys_tree}"
  assert jax.tree_util.tree_all(
      jax.tree_util.tree_map(
          lambda x, y: np.array(x).shape == np.array(y).shape, xs_leaves,
          ys_leaves)), "Leaves' shapes don't match."
  # _assert_numpy_allclose raises on mismatch, so mapping it over the leaves
  # is the actual check; its return value carries no information.
  jax.tree_util.tree_map(
      functools.partial(_assert_numpy_allclose, atol=atol, rtol=rtol),
      xs_leaves, ys_leaves)


def flattened_state_dict(x):
  s = flax.serialization.to_state_dict(x)
  return flax.traverse_util.flatten_dict(s, sep='/')


def tree_shape(x):
  return jax.tree_util.tree_map(jnp.shape, x)


def tree_equals(x, y):
  return jax.tree_util.tree_all(jax.tree_util.tree_map(operator.eq, x, y))
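

# Note: the reference dicts in the tests below use the flattened state-dict
# keys produced by flattened_state_dict, e.g.
#   flattened_state_dict({'a': {'v_row': ...}}) -> {'a/v_row': ...},
# which is where keys like 'a/m', 'a/v', 'a/v_col', and 'a/v_row' come from.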


def _get_multi_adafactor(
    learning_rate: float, step_offset: int,
    adafactor_exclude_from_parameter_scale: Sequence[str]
) -> optim.MultiOptimizer:
  """Get adafactor with support for excluding some parameters from scaling."""

  def _should_not_scale(path):
    return any(s in path for s in adafactor_exclude_from_parameter_scale)

  scaled_vars = traverse_util.ModelParamTraversal(
      lambda path, _: not _should_not_scale(path))
  unscaled_vars = traverse_util.ModelParamTraversal(
      lambda path, _: _should_not_scale(path))
  scaled_opt = optim.Adafactor(
      learning_rate, decay_rate=0.8, step_offset=step_offset)
  unscaled_opt = optim.Adafactor(
      learning_rate,
      decay_rate=0.8,
      step_offset=step_offset,
      multiply_by_parameter_scale=False)
  return optim.MultiOptimizer((scaled_vars, scaled_opt),
                              (unscaled_vars, unscaled_opt))
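

# _get_multi_adafactor mirrors, via the legacy flax MultiOptimizer, the
# per-parameter multiply_by_parameter_scale behavior that t5x's Adafactor
# expresses with an HParamMap; the two are compared in
# test_multiply_by_parameter_scale_equivalence below.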

# Inline test data
MODEL_SHAPE = {
    'decoder': {
        'decoder_norm': {'scale': [128]},
        'layers_0': {
            'encoder_decoder_attention': {
                'key': {'kernel': [128, 256]},
                'out': {'kernel': [256, 128]},
                'query': {'kernel': [128, 256]},
                'value': {'kernel': [128, 256]}},
            'mlp': {
                'wi': {'kernel': [128, 512]},
                'wo': {'kernel': [512, 128]}},
            'pre_cross_attention_layer_norm': {'scale': [128]},
            'pre_mlp_layer_norm': {'scale': [128]},
            'pre_self_attention_layer_norm': {'scale': [128]},
            'self_attention': {
                'key': {'kernel': [128, 256]},
                'out': {'kernel': [256, 128]},
                'query': {'kernel': [128, 256]},
                'value': {'kernel': [128, 256]}}},
        'layers_1': {
            'encoder_decoder_attention': {
                'key': {'kernel': [128, 128]},
                'out': {'kernel': [128, 128]},
                'query': {'kernel': [128, 128]},
                'value': {'kernel': [128, 128]}},
            'mlp': {
                'wi': {'kernel': [128, 512]},
                'wo': {'kernel': [512, 128]}},
            'pre_cross_attention_layer_norm': {'scale': [128]},
            'pre_mlp_layer_norm': {'scale': [128]},
            'pre_self_attention_layer_norm': {'scale': [128]},
            'self_attention': {
                'key': {'kernel': [128, 256]},
                'out': {'kernel': [256, 128]},
                'query': {'kernel': [128, 256]},
                'value': {'kernel': [128, 256]}}},
        'relpos_bias': {'rel_embedding': [2, 32]}},
    'encoder': {
        'encoder_norm': {'scale': [128]},
        'layers_0': {
            'attention': {
                'key': {'kernel': [128, 256]},
                'out': {'kernel': [256, 128]},
                'query': {'kernel': [128, 256]},
                'value': {'kernel': [128, 256]}},
            'mlp': {
                'wi': {'kernel': [128, 512]},
                'wo': {'kernel': [512, 128]}},
            'pre_attention_layer_norm': {'scale': [128]},
            'pre_mlp_layer_norm': {'scale': [128]}},
        'layers_1': {
            'attention': {
                'key': {'kernel': [128, 256]},
                'out': {'kernel': [256, 128]},
                'query': {'kernel': [128, 256]},
                'value': {'kernel': [128, 256]}},
            'mlp': {
                'wi': {'kernel': [128, 512]},
                'wo': {'kernel': [512, 128]}},
            'pre_attention_layer_norm': {'scale': [128]},
            'pre_mlp_layer_norm': {'scale': [128]}},
        'relpos_bias': {'rel_embedding': [2, 32]}},
    'token_embedder': {'embedding': [32128, 128]}}  # pyformat: disable


class AdafactorTest(parameterized.TestCase):

  # Classic Adafactor Behavior Tests
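  #
  # Background for the shape checks below (a sketch following Shazeer & Stern
  # (2018), not necessarily t5x's exact conventions): factoring replaces the
  # full second-moment accumulator of a (rows, cols) kernel with per-row and
  # per-column statistics, reconstructed roughly as
  #   v_hat ~= outer(row_stats, col_stats) / mean(row_stats),
  # so the per-parameter state shrinks from rows * cols to rows + cols numbers.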

  def test_2D_simple(self):
    x = {'a': jnp.ones((24, 16))}
    opt_def = adafactor.Adafactor(min_dim_size_to_factor=8)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
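    # Both dims exceed min_dim_size_to_factor, so the second moment is
    # factored: under the heuristic, v_row tracks the smaller dim (16) and
    # v_col the larger (24); the unused m and v slots collapse to shape (1,).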
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (16,)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_2D_simple_nofactor(self):
    x = {'a': jnp.ones((24, 16))}
    opt_def = adafactor.Adafactor(min_dim_size_to_factor=32)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (24, 16), 'a/v_col': (1,), 'a/v_row': (1,)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_2D_simple_nofactor_momentum(self):
    x = {'a': jnp.ones((24, 16))}
    opt_def = adafactor.Adafactor(min_dim_size_to_factor=32, beta1=0.1)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (24, 16), 'a/v': (24, 16), 'a/v_col': (1,), 'a/v_row': (1,)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_3D_simple(self):
    x = {'a': jnp.ones((24, 4, 16))}
    factor_map = adafactor.HParamMap((('a', (_COL, _BATCH, _ROW)),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
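    # (_COL, _BATCH, _ROW) marks dim 0 as a column dim, dim 1 as batch, and
    # dim 2 as a row dim: v_col keeps every non-row dim, (24, 4), and v_row
    # keeps every non-column dim, (4, 16).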
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (4, 16)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_init_state(self):
    params = {'x': np.zeros((3, 2))}
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1,
        decay_rate=0.8,
        beta1=None,
        min_dim_size_to_factor=0)
    state = optimizer_def.init_state(params)
    expected_hyper_params = _AdafactorHyperParams(0.1, True, True, None, 0.8, 0,
                                                  1.0, None, 0, 1e-30, 1e-3)
    self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
    expected_state = OptimizerState(
        0, {
            'x':
                _AdafactorParamState(
                    np.zeros((2,)), np.zeros((3,)), np.zeros((1,)),
                    np.zeros((1,)))
        })
    check_eq(state, expected_state)

    # unfactorized
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1,
        decay_rate=0.8,
        beta1=0.0,
        min_dim_size_to_factor=32)
    state = optimizer_def.init_state(params)
    expected_hyper_params = _AdafactorHyperParams(0.1, True, True, 0.0, 0.8, 0,
                                                  1.0, None, 32, 1e-30, 1e-3)
    self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
    expected_state = OptimizerState(
        0, {
            'x':
                _AdafactorParamState(
                    np.zeros((1,)), np.zeros((1,)), np.zeros((3, 2)),
                    np.zeros((3, 2)))
        })
    check_eq(state, expected_state)

  def test_apply_gradient(self):
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1, decay_rate=0.8, min_dim_size_to_factor=0)
    params = {'x': np.ones((3, 2), np.float32)}
    state = OptimizerState(
        1, {
            'x':
                _AdafactorParamState(
                    np.array([0.9, 0.9]), np.array([0.1, 0.1, 0.1]),
                    np.zeros((1,)), np.zeros((1,)))
        })
    grads = {'x': np.ones((3, 2), np.float32)}
    new_params, new_state = optimizer_def.apply_gradient(
        optimizer_def.hyper_params, params, state, grads)
    expected_new_state = OptimizerState(
        2, {
            'x':
                _AdafactorParamState(
                    np.array([0.9574349, 0.9574349]),
                    np.array([0.6169143, 0.6169143, 0.6169143]),
                    np.zeros((1,)), np.zeros((1,)))
        })
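    # The RMS-normalized update exceeds the clipping_threshold of 1.0 and is
    # clipped back to unit RMS, so each parameter moves by exactly the
    # learning rate: 1.0 - 0.1 * 1.0 = 0.9.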
    expected_new_params = {'x': 0.9 * np.ones((3, 2))}
    check_eq(new_params, expected_new_params)
    check_eq(new_state, expected_new_state, rtol=1e-6)

    # unfactored, with momentum
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1,
        beta1=0.0,
        decay_rate=0.8,
        min_dim_size_to_factor=32)
    params = {'x': np.ones((3, 2), np.float32)}
    state = OptimizerState(
        1, {
            'x':
                _AdafactorParamState(
                    np.zeros((1,)), np.zeros((1,)), 0.5 * np.ones((3, 2)),
                    np.zeros((3, 2)))
        })
    grads = {'x': np.ones((3, 2), np.float32)}
    new_params, new_state = optimizer_def.apply_gradient(
        optimizer_def.hyper_params, params, state, grads)
    expected_new_params = {'x': 0.9 * np.ones((3, 2))}
    check_eq(new_params, expected_new_params)
    expected_new_state = OptimizerState(
        2, {
            'x':
                _AdafactorParamState(
                    np.array([0.0]), np.array([0.0]),
                    0.787174 * np.ones((3, 2)), 0.1 * np.ones((3, 2)))
        })
    check_eq(new_state, expected_new_state, rtol=1e-6)

  def test_apply_gradient_with_global_norm_clipping(self):
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1,
        decay_rate=0.8,
        min_dim_size_to_factor=0,
        global_norm_clip_threshold=1.0)
    params = {'x': np.ones((3, 2), np.float32)}
    state = OptimizerState(
        1, {
            'x':
                _AdafactorParamState(
                    np.array([0.9, 0.9]), np.array([0.1, 0.1, 0.1]),
                    np.zeros((1,)), np.zeros((1,)))
        })
    grads = {'x': np.ones((3, 2), np.float32)}
    new_params, new_state = optimizer_def.apply_gradient(
        optimizer_def.hyper_params, params, state, grads)
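    # ones((3, 2)) has global norm sqrt(6) ~= 2.449 > 1.0, so every gradient
    # entry is rescaled to 1 / sqrt(6) before the Adafactor update; the squared
    # gradient entering the second-moment statistics is then 1/6 rather than 1,
    # which is why these values are smaller than in test_apply_gradient above.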
    expected_new_state = OptimizerState(
        2, {
            'x':
                _AdafactorParamState(
                    np.array([0.478811, 0.478811]),
                    np.array([0.13829, 0.13829, 0.13829]),
                    np.zeros((1,)), np.zeros((1,)))
        })
    expected_new_params = {'x': 0.9 * np.ones((3, 2))}
    check_eq(new_params, expected_new_params)
    check_eq(new_state, expected_new_state, rtol=1e-6)

  def test_factorizes(self):
    params = {'x': np.zeros((64, 64))}
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1,
        decay_rate=0.8,
        beta1=None,
        min_dim_size_to_factor=32)
    state = optimizer_def.init_state(params)
    self.assertEqual(state.param_states['x'].v.shape, (1,))
    self.assertEqual(state.param_states['x'].m.shape, (1,))
    self.assertEqual(state.param_states['x'].v_row.shape, (64,))
    self.assertEqual(state.param_states['x'].v_col.shape, (64,))
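
    # Factoring requires every factored dim to reach min_dim_size_to_factor;
    # shrinking one dim below the threshold falls back to the full (31, 64)
    # second-moment accumulator.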
    params = {'x': np.zeros((31, 64))}
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1,
        decay_rate=0.8,
        beta1=None,
        min_dim_size_to_factor=32)
    state = optimizer_def.init_state(params)
    self.assertEqual(state.param_states['x'].v.shape, (31, 64))
    self.assertEqual(state.param_states['x'].m.shape, (1,))
    self.assertEqual(state.param_states['x'].v_row.shape, (1,))
    self.assertEqual(state.param_states['x'].v_col.shape, (1,))

  # Manually specified factorization rules tests.

  @parameterized.parameters(
      # Representative rules (assumed; the original parameter list was lost).
      # Any explicit 2D rule should exercise the fallback.
      {'rule': (_ROW, _COL)},
      {'rule': (_COL, _ROW)},
  )
  def test_2D_ignore_specified_factor_rule(self, rule):
    x = {'a': jnp.ones((24, 16))}
    factor_map = adafactor.HParamMap((('a', rule),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    # Since the param is 2D, the explicit factor rule is ignored and the
    # optimizer falls back to the heuristic, where v_row corresponds to the
    # smaller dim.
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (16,)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_3D_simple_manual_rules(self):
    x = {'a': jnp.ones((24, 4, 16))}

    factor_map = adafactor.HParamMap((('a', (_COL, _BATCH, _ROW)),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (4, 16)}
    self.assertTrue(tree_equals(shapes, ref))

    factor_map = adafactor.HParamMap((('a', (_ROW, _BATCH, _COL)),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (4, 16), 'a/v_row': (24, 4)}
    self.assertTrue(tree_equals(shapes, ref))

    factor_map = adafactor.HParamMap((('a', (_COL, _ROW, _ROW)),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (4, 16)}
    self.assertTrue(tree_equals(shapes, ref))

    factor_map = adafactor.HParamMap((('a', (_COL, _COL, _ROW)),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (16,)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_standard_factor_rules(self):
    # One-off test to double-check that we're following the previous
    # heuristic convention for rows/columns.

    def standard_factor_rules():
      token_embedding = (_COL, _ROW)
      attn_qkv = (_ROW, _COL)
      attn_out = (_COL, _ROW)
      mlp_in = (_ROW, _COL)
      mlp_out = (_COL, _ROW)
      return (
          (r'_layer_norm/(bias|scale)', None),
          (r'(encoder|decoder)_norm/(bias|scale)', None),
          (r'(encoder_decoder_|self_|\b)attention/(query|key|value)/kernel',
           attn_qkv),
          (r'(encoder_decoder_|self_|\b)attention/out/kernel', attn_out),
          (r'mlp/DenseGeneral_\d+/bias', None),
          (r'mlp/wi(_\d+)?/kernel', mlp_in),
          (r'mlp/wo/kernel', mlp_out),
          (r'\brelpos_bias', None),
          (r'token_embedder', token_embedding),
          (r'.*', adafactor.HEURISTIC_RULE),
      )
    # Create fake model parameters.
    k = jax.random.PRNGKey(0)
    params = jax.tree_util.tree_map(
        lambda shape: jax.random.uniform(k, shape),
        MODEL_SHAPE,
        is_leaf=lambda x: isinstance(x, list))
    # Make traditional adafactor state with the heuristic.
    factor_map1 = adafactor.HParamMap(((r'.*', adafactor.HEURISTIC_RULE),))
    optimizer_def1 = adafactor.Adafactor(
        0.1,
        decay_rate=0.8,
        step_offset=0,
        multiply_by_parameter_scale=True,
        factor_map=factor_map1)
    optimizer1 = optimizer_def1.create(params)
    # Make adafactor state with the explicit rules.
    factor_map2 = adafactor.HParamMap(standard_factor_rules())
    optimizer_def2 = adafactor.Adafactor(
        0.1,
        decay_rate=0.8,
        step_offset=0,
        multiply_by_parameter_scale=True,
        factor_map=factor_map2)
    optimizer2 = optimizer_def2.create(params)
    # The two parameterizations should produce identical optimizer states.
    check_eq(optimizer1.state.param_states, optimizer2.state.param_states)

  @parameterized.parameters(
      # Representative shapes (assumed; the original parameter list was lost).
      # factor_map=None only supports 1D and 2D parameters.
      {'shape': (64, 64)},
      {'shape': (77,)},
      {'shape': (256, 128)},
  )
  def test_no_factor_map_equivalence(self, shape):
    k = random.PRNGKey(0)
    k1, k2 = random.split(k)
    p = {'a': random.uniform(k1, shape)}
    g = {'a': random.uniform(k2, shape)}
    orig_opt = optim.Adafactor(0.1).create(p)
    new_opt = adafactor.Adafactor(0.1, factor_map=None).create(p)
    check_eq(orig_opt.state_dict(), new_opt.state_dict())
    orig_opt1 = orig_opt.apply_gradient(g)
    new_opt1 = new_opt.apply_gradient(g)
    check_eq(orig_opt1.state_dict(), new_opt1.state_dict())

  @parameterized.parameters(
      # Representative cases (assumed; the original parameter list was lost).
      # HEURISTIC_RULE for 'a' must match flax's built-in heuristic exactly.
      {'shape': (64, 64), 'rule': adafactor.HEURISTIC_RULE},
      {'shape': (256, 128), 'rule': adafactor.HEURISTIC_RULE},
  )
  def test_simple_equivalence(self, shape, rule):
    k = random.PRNGKey(0)
    k1, k2 = random.split(k)
    k3, k4 = random.split(k1)
    k5, k6 = random.split(k2)
    p = {'a': random.uniform(k3, shape), 'b': random.uniform(k4, shape)}
    g = {'a': random.uniform(k5, shape), 'b': random.uniform(k6, shape)}
    orig_opt = optim.Adafactor(0.1).create(p)
    factor_map = adafactor.HParamMap(
        rules=(('a', rule), ('.*', adafactor.HEURISTIC_RULE)))
    new_opt = adafactor.Adafactor(0.1, factor_map=factor_map).create(p)
    check_eq(orig_opt.state_dict(), new_opt.state_dict())
    orig_opt1 = orig_opt.apply_gradient(g)
    new_opt1 = new_opt.apply_gradient(g)
    check_eq(orig_opt1.state_dict(), new_opt1.state_dict())

  @parameterized.parameters(
      # A representative shape (assumed; the original parameter list was lost).
      {'shape': (64, 64)},
  )
  def test_multiply_by_parameter_scale_equivalence(self, shape):
    # Use large parameter values to magnify the parameter scaling effect.
    p = {
        'a': np.random.randn(*shape) * 100,
        'b': np.random.randn(*shape) * 100
    }
    g = {'a': np.random.randn(*shape), 'b': np.random.randn(*shape)}
    orig_opt = _get_multi_adafactor(
        3.0, 0, adafactor_exclude_from_parameter_scale=('a',)).create(p)
    scaling_map = adafactor.HParamMap([('a', False), ('.*', True)])
    new_opt = adafactor.Adafactor(
        3.0, multiply_by_parameter_scale=scaling_map).create(p)
    check_eq(orig_opt.state_dict(), new_opt.state_dict())
    orig_opt1 = orig_opt.apply_gradient(g)
    new_opt1 = new_opt.apply_gradient(g)
    check_eq(orig_opt1.state_dict(), new_opt1.state_dict())

  def test_3d_without_factor_map(self):
    x = {'a': jnp.ones((24, 4, 16))}
    opt_def = adafactor.Adafactor(factor_map=None)
    with self.assertRaises(ValueError):
      _ = opt_def.create(x)


if __name__ == '__main__':
  absltest.main()