# Spaces:
# Build error
# Build error
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for train_state."""
from absl.testing import absltest
from flax import linen as nn
from flax import optim
import flax.core
from flax.linen import partitioning as flax_partitioning
import jax
import numpy as np
from t5x import adafactor
from t5x import optimizers
from t5x import partitioning
from t5x import train_state as train_state_lib

# Short aliases used throughout the test cases below.
mock = absltest.mock
AxisMetadata = flax_partitioning.AxisMetadata
FactorDim = adafactor.FactorDim
class FlaxOptimTrainStateTest(absltest.TestCase):
  """Unit tests for `train_state_lib.FlaxOptimTrainState`."""

  def test_init(self):
    """Direct construction exposes step, optimizer, params and mutables."""
    model = nn.Dense(10)
    inputs = np.ones([2, 3], dtype=np.float32)
    params = model.init(jax.random.PRNGKey(0), inputs)['params']
    optimizer_def = optimizers.adam(0.1)
    optimizer = optimizer_def.create(params)
    flax_mutables = flax.core.freeze({'flax_mutable1': np.ones(10)})
    state = train_state_lib.FlaxOptimTrainState(
        optimizer, flax_mutables=flax_mutables)
    self.assertEqual(state.step, 0)
    self.assertIsInstance(state._optimizer, optimizers.Optimizer)
    self.assertEqual(state.state_dict()['flax_mutables'],
                     flax.core.unfreeze(flax_mutables))
    # `jax.tree_util.tree_map` accepts several trees and applies the function
    # leaf-wise; it replaces the removed `jax.tree_multimap`.
    jax.tree_util.tree_map(np.testing.assert_array_equal, params, state.params)
    jax.tree_util.tree_map(np.testing.assert_array_equal,
                           optimizer.state.param_states, state.param_states)

  def test_create(self):
    """`create` separates params from the remaining (mutable) collections."""
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'mutables': np.ones(3)
    })
    optimizer_def = optimizers.sgd(0.42)
    state = train_state_lib.FlaxOptimTrainState.create(optimizer_def,
                                                       model_variables)
    self.assertEqual(state.step, 0)
    self.assertIsInstance(state._optimizer, optimizers.Optimizer)
    self.assertEqual(state._optimizer.optimizer_def, optimizer_def)
    jax.tree_util.tree_map(
        np.testing.assert_array_equal, state.flax_mutables,
        flax.core.freeze({'mutables': np.ones(3)}))
    jax.tree_util.tree_map(np.testing.assert_array_equal, state.params,
                           model_variables['params'])
    self.assertIsNone(state.params_axes)

  def test_create_with_params_axes(self):
    """`create` with `params_axes` fills in the Adafactor factor map."""
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
                'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
            }
        },
    })
    optimizer_def = adafactor.Adafactor(
        0.42,
        logical_factor_rules={
            'vocab': FactorDim.COLUMN,
            'embed': FactorDim.ROW
        })
    state = train_state_lib.FlaxOptimTrainState.create(optimizer_def,
                                                       model_variables)
    self.assertEqual(state.step, 0)
    self.assertIsInstance(state._optimizer, optimizers.Optimizer)
    self.assertEqual(state._optimizer.optimizer_def, optimizer_def)
    self.assertDictEqual(
        state._optimizer.optimizer_def.hyper_params.factor_map, {
            'dense/bias': (FactorDim.NONE,),
            'dense/kernel': (FactorDim.COLUMN, FactorDim.ROW)
        })
    self.assertEqual(state.flax_mutables, flax.core.freeze({}))
    jax.tree_util.tree_map(np.testing.assert_array_equal,
                           model_variables['params'], state.params)
    jax.tree_util.tree_map(np.testing.assert_array_equal,
                           model_variables['params_axes'], state.params_axes)

  def test_create_with_flax_mutables_axes(self):
    """`create` keeps extra collections and their `*_axes` metadata."""
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
                'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
            }
        },
        'grads': {
            'dense': {
                'output_grad': np.zeros(4),
            }
        },
        'grads_axes': {
            'dense': {
                'output_grad': AxisMetadata(names=('embed',)),
            }
        },
    })
    optimizer_def = adafactor.Adafactor(
        0.42,
        logical_factor_rules={
            'vocab': FactorDim.COLUMN,
            'embed': FactorDim.ROW
        })
    state = train_state_lib.FlaxOptimTrainState.create(optimizer_def,
                                                       model_variables)
    self.assertEqual(state.step, 0)
    self.assertIsInstance(state._optimizer, optimizers.Optimizer)
    self.assertEqual(state._optimizer.optimizer_def, optimizer_def)
    self.assertDictEqual(
        state._optimizer.optimizer_def.hyper_params.factor_map, {
            'dense/bias': (FactorDim.NONE,),
            'dense/kernel': (FactorDim.COLUMN, FactorDim.ROW)
        })
    self.assertEqual(state.flax_mutables,
                     flax.core.freeze({'grads': model_variables['grads']}))
    jax.tree_util.tree_map(np.testing.assert_array_equal,
                           model_variables['params'], state.params)
    jax.tree_util.tree_map(np.testing.assert_array_equal,
                           model_variables['params_axes'], state.params_axes)
    jax.tree_util.tree_map(np.testing.assert_array_equal,
                           model_variables['grads_axes'],
                           state.flax_mutables_axes['grads'])

  def test_create_missing_params_axes(self):
    """Adafactor with a model that emits no `params_axes` raises ValueError."""
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'mutables': np.ones(3)
    })
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        'The optimizer supports params_axes for model-based partitioning, but '
        'the model is not emitting them.'):
      train_state_lib.FlaxOptimTrainState.create(adafactor.Adafactor(),
                                                 model_variables)

  def test_create_mismatched_params_axes(self):
    """A parameter without a matching axes entry raises ValueError."""
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
            }
        },
        'mutables': np.ones(3)
    })
    with self.assertRaisesWithLiteralMatch(
        ValueError, "Missing axis names for parameters: {'dense/kernel'}"):
      train_state_lib.FlaxOptimTrainState.create(adafactor.Adafactor(),
                                                 model_variables)

  def test_replace_params(self):
    """`replace_params` swaps the target and leaves the rest of the state."""
    optimizer_def = optimizers.sgd(0.1)
    optimizer = optimizer_def.create({'test': np.ones(10)})
    state = train_state_lib.FlaxOptimTrainState(optimizer)

    new_params = {'test': np.zeros(10)}
    new_state = state.replace_params(new_params)
    jax.tree_util.tree_map(np.testing.assert_array_equal, new_params,
                           new_state.params)
    # Everything except 'target' must match the original state dict.
    expected_state_dict = state.state_dict()
    expected_state_dict['target'] = new_params
    jax.tree_util.tree_map(np.testing.assert_array_equal, expected_state_dict,
                           new_state.state_dict())

  def test_replace_step(self):
    """`replace_step` returns a state reporting the new step."""
    optimizer_def = optimizers.adam(0.1)
    optimizer = optimizer_def.create({'test': np.ones(10)})
    state = train_state_lib.FlaxOptimTrainState(optimizer)

    self.assertEqual(state.step, 0)
    self.assertEqual(state.replace_step(jax.numpy.array(1)).step, 1)

  def test_apply_gradient(self):
    """`apply_gradient` delegates to the optimizer and stores new mutables."""
    updated_optimizer = object()
    optimizer = mock.Mock(
        apply_gradient=mock.Mock(return_value=updated_optimizer))
    state = train_state_lib.FlaxOptimTrainState(optimizer)

    new_flax_mutables = {'test': 44}
    new_state = state.apply_gradient(
        grads=42, learning_rate=43, flax_mutables={'test': 44})

    optimizer.apply_gradient.assert_called_once_with(42, learning_rate=43)

    self.assertEqual(new_state._optimizer, updated_optimizer)
    self.assertEqual(
        new_state,
        train_state_lib.FlaxOptimTrainState(
            updated_optimizer, flax_mutables=new_flax_mutables))

  def test_as_logical_axes(self):
    """`as_logical_axes` maps each param to a PartitionSpec of axis names."""
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
                'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
            }
        },
    })
    optimizer_def = adafactor.Adafactor(
        0.42,
        logical_factor_rules={
            'vocab': FactorDim.COLUMN,
            'embed': FactorDim.ROW
        })
    state = train_state_lib.FlaxOptimTrainState.create(optimizer_def,
                                                       model_variables)
    axes_state = state.as_logical_axes()
    self.assertIsNone(axes_state.params_axes)
    jax.tree_util.tree_map(
        np.testing.assert_array_equal, axes_state.params,
        flax.core.freeze({
            'dense': {
                'bias': partitioning.PartitionSpec('embed'),
                'kernel': partitioning.PartitionSpec('vocab', 'embed'),
            }
        }))

  def test_as_logical_axes_with_flax_mutables(self):
    """`as_logical_axes` also maps flax mutables to PartitionSpecs."""
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
                'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
            }
        },
        'grads': {
            'dense': {
                'output_grad': np.zeros(4),
            }
        },
        'grads_axes': {
            'dense': {
                'output_grad': AxisMetadata(names=('embed',)),
            }
        },
    })
    optimizer_def = adafactor.Adafactor(
        0.42,
        logical_factor_rules={
            'vocab': FactorDim.COLUMN,
            'embed': FactorDim.ROW
        })
    state = train_state_lib.FlaxOptimTrainState.create(optimizer_def,
                                                       model_variables)
    axes_state = state.as_logical_axes()
    self.assertIsNone(axes_state.params_axes)
    jax.tree_util.tree_map(
        np.testing.assert_array_equal, axes_state.flax_mutables,
        flax.core.freeze({
            'grads': {
                'dense': {
                    'output_grad': partitioning.PartitionSpec('embed'),
                }
            }
        }))

  def test_as_logical_axes_unsupported_optimizer(self):
    """`as_logical_axes` raises ValueError for `optim.GradientDescent`."""
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
                'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
            }
        },
    })
    optimizer_def = optim.GradientDescent(0.42)
    state = train_state_lib.FlaxOptimTrainState.create(optimizer_def,
                                                       model_variables)
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "Optimizer 'GradientDescent' requires a `derive_logical_axes` method "
        'to be used with named axis partitioning.'):
      state.as_logical_axes()

  def test_to_state_dict(self):
    """`state_dict` contains optimizer state, target and flax mutables."""
    model_variables = flax.core.freeze({
        'params': {
            'kernel': np.zeros((2, 4))
        },
        'params_axes': {
            'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
        },
        'mutables': np.ones(3)
    })
    optimizer_def = adafactor.Adafactor(
        0.42,
        logical_factor_rules={
            'vocab': FactorDim.COLUMN,
            'embed': FactorDim.ROW
        })
    state = train_state_lib.FlaxOptimTrainState.create(optimizer_def,
                                                       model_variables)
    jax.tree_util.tree_map(
        np.testing.assert_array_equal, state.state_dict(), {
            'state': {
                'step': np.array(0),
                'param_states': {
                    'kernel': {
                        'm': np.zeros(1),
                        'v': np.zeros((2, 4)),
                        'v_col': np.zeros(1),
                        'v_row': np.zeros(1)
                    },
                }
            },
            'target': {
                'kernel': np.zeros((2, 4))
            },
            'flax_mutables': {
                'mutables': np.ones(3)
            }
        })

  def test_restore_state(self):
    """`restore_state` loads step, params, param states and mutables."""
    model_variables = flax.core.freeze({
        'params': {
            'kernel': np.zeros((2, 4))
        },
        'params_axes': {
            'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
        },
        'mutables': np.ones(3)
    })
    optimizer_def = adafactor.Adafactor(
        0.42,
        logical_factor_rules={
            'vocab': FactorDim.COLUMN,
            'embed': FactorDim.ROW
        })
    state = train_state_lib.FlaxOptimTrainState.create(optimizer_def,
                                                       model_variables)
    restored = state.restore_state({
        'state': {
            'step': np.array(1),
            'param_states': {
                'kernel': {
                    'm': np.ones(1),
                    'v': np.ones((2, 4)),
                    'v_col': np.ones(1),
                    'v_row': np.ones(1)
                },
            }
        },
        'target': {
            'kernel': np.ones((2, 4))
        },
        'flax_mutables': {
            'mutables': np.zeros(3)
        }
    })

    self.assertEqual(restored.step, 1)
    self.assertIsInstance(restored._optimizer, optimizers.Optimizer)
    self.assertEqual(restored._optimizer.optimizer_def, optimizer_def)
    jax.tree_util.tree_map(
        np.testing.assert_array_equal, restored.flax_mutables,
        flax.core.freeze({'mutables': np.zeros(3)}))
    jax.tree_util.tree_map(
        np.testing.assert_array_equal, restored.params,
        flax.core.freeze({'kernel': np.ones((2, 4))}))
    jax.tree_util.tree_map(
        np.testing.assert_array_equal, restored.param_states,
        flax.core.freeze({
            'kernel':
                adafactor._AdafactorParamState(
                    np.ones(1), np.ones(1), np.ones((2, 4)), np.ones(1))
        }))
    # The axes metadata comes from the state being restored into, not from
    # the restored dict (which carries no axes entry).
    jax.tree_util.tree_map(np.testing.assert_array_equal, restored.params_axes,
                           model_variables['params_axes'])
class InferenceStateTest(absltest.TestCase):
  """Unit tests for `train_state_lib.InferenceState`."""

  def test_init(self):
    """Direct construction exposes step, params and flax mutables."""
    model = nn.Dense(10)
    inputs = np.ones([2, 3], dtype=np.float32)
    params = model.init(jax.random.PRNGKey(0), inputs)['params']
    flax_mutables = flax.core.freeze({'flax_mutable1': np.ones(10)})
    state = train_state_lib.InferenceState(
        step=jax.numpy.array(1), params=params, flax_mutables=flax_mutables)
    self.assertEqual(state.step, 1)
    self.assertEqual(state.flax_mutables, flax.core.unfreeze(flax_mutables))
    # `jax.tree_util.tree_map` accepts several trees and applies the function
    # leaf-wise; it replaces the removed `jax.tree_multimap`.
    jax.tree_util.tree_map(np.testing.assert_array_equal, params, state.params)
    self.assertIsNone(state.params_axes)

  def test_create(self):
    """`create` splits variables into params, params_axes and mutables."""
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
                'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
            }
        },
        'mutables': np.ones(3)
    })
    state = train_state_lib.InferenceState.create(model_variables)
    self.assertEqual(state.step, 0)
    jax.tree_util.tree_map(
        np.testing.assert_array_equal, state.flax_mutables,
        flax.core.freeze({'mutables': np.ones(3)}))
    jax.tree_util.tree_map(np.testing.assert_array_equal, state.params,
                           model_variables['params'])
    jax.tree_util.tree_map(np.testing.assert_array_equal, state.params_axes,
                           model_variables['params_axes'])

  def test_create_mismatched_params_axes(self):
    """A parameter without a matching axes entry raises ValueError."""
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
            }
        },
        'mutables': np.ones(3)
    })
    with self.assertRaisesWithLiteralMatch(
        ValueError, "Missing axis names for parameters: {'dense/kernel'}"):
      train_state_lib.InferenceState.create(model_variables)

  def test_replace_params(self):
    """`replace_params` returns a state holding the new params."""
    model_variables = flax.core.freeze({'params': {'test': np.ones(10)}})
    state = train_state_lib.InferenceState.create(model_variables)

    new_params = {'test': np.zeros(10)}
    new_state = state.replace_params(new_params)
    jax.tree_util.tree_map(np.testing.assert_array_equal, new_params,
                           new_state.params)

  def test_replace_step(self):
    """`replace_step` returns a state reporting the new step."""
    model_variables = flax.core.freeze({'params': {'test': np.ones(10)}})
    state = train_state_lib.InferenceState.create(model_variables)

    self.assertEqual(state.step, 0)
    self.assertEqual(state.replace_step(jax.numpy.array(1)).step, 1)

  def test_as_logical_axes(self):
    """`as_logical_axes` maps each param to a PartitionSpec of axis names."""
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
                'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
            }
        },
    })
    state = train_state_lib.InferenceState.create(model_variables)
    axes_state = state.as_logical_axes()
    self.assertIsNone(axes_state.params_axes)
    jax.tree_util.tree_map(
        np.testing.assert_array_equal, axes_state.params,
        flax.core.freeze({
            'dense': {
                'bias': partitioning.PartitionSpec('embed'),
                'kernel': partitioning.PartitionSpec('vocab', 'embed'),
            }
        }))

  def test_to_state_dict(self):
    """`state_dict` contains step, target and flax mutables."""
    model_variables = flax.core.freeze({
        'params': {
            'bias': np.zeros(4),
        },
        'params_axes': {
            'bias_axes': AxisMetadata(names=('embed',)),
        },
        'mutables': np.ones(3)
    })
    state = train_state_lib.InferenceState.create(model_variables)
    jax.tree_util.tree_map(
        np.testing.assert_array_equal, state.state_dict(), {
            'state': {
                'step': np.array(0)
            },
            'target': {
                'bias': np.zeros(4),
            },
            'flax_mutables': {
                'mutables': np.ones(3)
            }
        })

  def test_to_state_dict_no_mutables(self):
    """Without mutables, `state_dict` has no 'flax_mutables' entry."""
    model_variables = flax.core.freeze({
        'params': {
            'bias': np.zeros(4),
        },
        'params_axes': {
            'bias_axes': AxisMetadata(names=('embed',)),
        },
    })
    state = train_state_lib.InferenceState.create(model_variables)
    jax.tree_util.tree_map(
        np.testing.assert_array_equal, state.state_dict(), {
            'state': {
                'step': np.array(0)
            },
            'target': {
                'bias': np.zeros(4),
            },
        })

  def test_restore_state(self):
    """`restore_state` loads step, params and mutables; axes are kept."""
    state = train_state_lib.InferenceState(
        np.array(0), {'bias': np.zeros(4)},
        {'bias_axes': AxisMetadata(names=('embed',))})

    state_dict = {
        'state': {
            'step': np.array(10)
        },
        'target': {
            'bias': np.ones(4),
        },
        'flax_mutables': {
            'mutables': np.ones(3)
        }
    }
    restored = state.restore_state(state_dict)

    self.assertEqual(restored.step, 10)
    jax.tree_util.tree_map(
        np.testing.assert_array_equal, restored.flax_mutables,
        flax.core.freeze(state_dict['flax_mutables']))
    jax.tree_util.tree_map(
        np.testing.assert_array_equal, restored.params,
        flax.core.freeze(state_dict['target']))
    self.assertEqual(restored.params_axes,
                     {'bias_axes': AxisMetadata(names=('embed',))})

  def test_restore_state_no_mutables_no_axes(self):
    """Restoring without mutables or axes gives EMPTY_DICT and None axes."""
    state = train_state_lib.InferenceState(np.array(0), {})

    state_dict = {
        'state': {
            'step': np.array(10)
        },
        'target': {
            'bias': np.zeros(4),
        },
    }
    restored = state.restore_state(state_dict)

    self.assertEqual(restored.step, 10)
    self.assertEqual(restored.flax_mutables, train_state_lib.EMPTY_DICT)
    jax.tree_util.tree_map(
        np.testing.assert_array_equal, restored.params,
        flax.core.freeze(state_dict['target']))
    self.assertIsNone(restored.params_axes)
# Run every test case in this module when it is executed as a script.
if __name__ == '__main__':
  absltest.main()