# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for train_state."""

from absl.testing import absltest
from flax import linen as nn
from flax import optim
import flax.core
from flax.linen import partitioning as flax_partitioning
import jax
import numpy as np
from t5x import adafactor
from t5x import optimizers
from t5x import partitioning
from t5x import train_state as train_state_lib

mock = absltest.mock
AxisMetadata = flax_partitioning.AxisMetadata
FactorDim = adafactor.FactorDim


class FlaxOptimTrainStateTest(absltest.TestCase):

  def test_init(self):
    model = nn.Dense(10)
    inputs = np.ones([2, 3], dtype=np.float32)
    params = model.init(jax.random.PRNGKey(0), inputs)['params']
    optimizer_def = optimizers.adam(0.1)
    optimizer = optimizer_def.create(params)
    flax_mutables = flax.core.freeze({'flax_mutable1': np.ones(10)})
    state = train_state_lib.FlaxOptimTrainState(
        optimizer, flax_mutables=flax_mutables)
    self.assertEqual(state.step, 0)
    self.assertIsInstance(state._optimizer, optimizers.Optimizer)
    self.assertEqual(state.state_dict()['flax_mutables'],
                     flax.core.unfreeze(flax_mutables))
    jax.tree_multimap(np.testing.assert_array_equal, params, state.params)
    jax.tree_multimap(np.testing.assert_array_equal,
                      optimizer.state.param_states, state.param_states)

  def test_create(self):
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'mutables': np.ones(3)
    })
    optimizer_def = optimizers.sgd(0.42)
    state = train_state_lib.FlaxOptimTrainState.create(optimizer_def,
                                                       model_variables)
    self.assertEqual(state.step, 0)
    self.assertIsInstance(state._optimizer, optimizers.Optimizer)
    self.assertEqual(state._optimizer.optimizer_def, optimizer_def)
    jax.tree_multimap(np.testing.assert_array_equal, state.flax_mutables,
                      flax.core.freeze({'mutables': np.ones(3)}))
    jax.tree_multimap(np.testing.assert_array_equal, state.params,
                      model_variables['params'])
    self.assertIsNone(state.params_axes)

  def test_create_with_params_axes(self):
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
                'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
            }
        },
    })
    optimizer_def = adafactor.Adafactor(
        0.42,
        logical_factor_rules={
            'vocab': FactorDim.COLUMN,
            'embed': FactorDim.ROW
        })
    state = train_state_lib.FlaxOptimTrainState.create(optimizer_def,
                                                       model_variables)
    self.assertEqual(state.step, 0)
    self.assertIsInstance(state._optimizer, optimizers.Optimizer)
    self.assertEqual(state._optimizer.optimizer_def, optimizer_def)
    self.assertDictEqual(
        state._optimizer.optimizer_def.hyper_params.factor_map, {
            'dense/bias': (FactorDim.NONE,),
            'dense/kernel': (FactorDim.COLUMN, FactorDim.ROW)
        })
    self.assertEqual(state.flax_mutables, flax.core.freeze({}))
    jax.tree_multimap(np.testing.assert_array_equal,
                      model_variables['params'], state.params)
    jax.tree_multimap(np.testing.assert_array_equal,
                      model_variables['params_axes'], state.params_axes)
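
  # Variable collections other than 'params'/'params_axes' are captured as
  # `flax_mutables`; a sibling '<collection>_axes' collection, when present,
  # supplies their partitioning metadata via `flax_mutables_axes`.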
  def test_create_with_flax_mutables_axes(self):
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
                'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
            }
        },
        'grads': {
            'dense': {
                'output_grad': np.zeros(4),
            }
        },
        'grads_axes': {
            'dense': {
                'output_grad': AxisMetadata(names=('embed',)),
            }
        },
    })
    optimizer_def = adafactor.Adafactor(
        0.42,
        logical_factor_rules={
            'vocab': FactorDim.COLUMN,
            'embed': FactorDim.ROW
        })
    state = train_state_lib.FlaxOptimTrainState.create(optimizer_def,
                                                       model_variables)
    self.assertEqual(state.step, 0)
    self.assertIsInstance(state._optimizer, optimizers.Optimizer)
    self.assertEqual(state._optimizer.optimizer_def, optimizer_def)
    self.assertDictEqual(
        state._optimizer.optimizer_def.hyper_params.factor_map, {
            'dense/bias': (FactorDim.NONE,),
            'dense/kernel': (FactorDim.COLUMN, FactorDim.ROW)
        })
    self.assertEqual(state.flax_mutables,
                     flax.core.freeze({'grads': model_variables['grads']}))
    jax.tree_multimap(np.testing.assert_array_equal,
                      model_variables['params'], state.params)
    jax.tree_multimap(np.testing.assert_array_equal,
                      model_variables['params_axes'], state.params_axes)
    jax.tree_multimap(np.testing.assert_array_equal,
                      model_variables['grads_axes'],
                      state.flax_mutables_axes['grads'])

  def test_create_missing_params_axes(self):
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'mutables': np.ones(3)
    })
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        'The optimizer supports params_axes for model-based partitioning, but '
        'the model is not emitting them.'):
      train_state_lib.FlaxOptimTrainState.create(adafactor.Adafactor(),
                                                 model_variables)

  def test_create_mismatched_params_axes(self):
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
            }
        },
        'mutables': np.ones(3)
    })
    with self.assertRaisesWithLiteralMatch(
        ValueError, "Missing axis names for parameters: {'dense/kernel'}"):
      train_state_lib.FlaxOptimTrainState.create(adafactor.Adafactor(),
                                                 model_variables)

  def test_replace_params(self):
    optimizer_def = optimizers.sgd(0.1)
    optimizer = optimizer_def.create({'test': np.ones(10)})
    state = train_state_lib.FlaxOptimTrainState(optimizer)
    new_params = {'test': np.zeros(10)}
    new_state = state.replace_params(new_params)
    jax.tree_multimap(np.testing.assert_array_equal, new_params,
                      new_state.params)
    expected_state_dict = state.state_dict()
    expected_state_dict['target'] = new_params
    jax.tree_multimap(np.testing.assert_array_equal, expected_state_dict,
                      new_state.state_dict())

  def test_replace_step(self):
    optimizer_def = optimizers.adam(0.1)
    optimizer = optimizer_def.create({'test': np.ones(10)})
    state = train_state_lib.FlaxOptimTrainState(optimizer)
    self.assertEqual(state.step, 0)
    self.assertEqual(state.replace_step(jax.numpy.array(1)).step, 1)

  def test_apply_gradient(self):
    updated_optimizer = object()
    optimizer = mock.Mock(
        apply_gradient=mock.Mock(return_value=updated_optimizer))
    state = train_state_lib.FlaxOptimTrainState(optimizer)
    new_flax_mutables = {'test': 44}
    new_state = state.apply_gradient(
        grads=42, learning_rate=43, flax_mutables={'test': 44})
    optimizer.apply_gradient.assert_called_once_with(42, learning_rate=43)
    self.assertEqual(new_state._optimizer, updated_optimizer)
    self.assertEqual(
        new_state,
        train_state_lib.FlaxOptimTrainState(
            updated_optimizer, flax_mutables=new_flax_mutables))
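
  # `as_logical_axes` should return a state whose leaves are PartitionSpecs
  # derived from each parameter's AxisMetadata names, with `params_axes`
  # itself cleared.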
  def test_as_logical_axes(self):
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
                'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
            }
        },
    })
    optimizer_def = adafactor.Adafactor(
        0.42,
        logical_factor_rules={
            'vocab': FactorDim.COLUMN,
            'embed': FactorDim.ROW
        })
    state = train_state_lib.FlaxOptimTrainState.create(optimizer_def,
                                                       model_variables)
    axes_state = state.as_logical_axes()
    self.assertIsNone(axes_state.params_axes)
    jax.tree_multimap(
        np.testing.assert_array_equal, axes_state.params,
        flax.core.freeze({
            'dense': {
                'bias': partitioning.PartitionSpec('embed'),
                'kernel': partitioning.PartitionSpec('vocab', 'embed'),
            }
        }))

  def test_as_logical_axes_with_flax_mutables(self):
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
                'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
            }
        },
        'grads': {
            'dense': {
                'output_grad': np.zeros(4),
            }
        },
        'grads_axes': {
            'dense': {
                'output_grad': AxisMetadata(names=('embed',)),
            }
        },
    })
    optimizer_def = adafactor.Adafactor(
        0.42,
        logical_factor_rules={
            'vocab': FactorDim.COLUMN,
            'embed': FactorDim.ROW
        })
    state = train_state_lib.FlaxOptimTrainState.create(optimizer_def,
                                                       model_variables)
    axes_state = state.as_logical_axes()
    self.assertIsNone(axes_state.params_axes)
    jax.tree_multimap(
        np.testing.assert_array_equal, axes_state.flax_mutables,
        flax.core.freeze({
            'grads': {
                'dense': {
                    'output_grad': partitioning.PartitionSpec('embed'),
                }
            }
        }))

  def test_as_logical_axes_unsupported_optimizer(self):
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
                'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
            }
        },
    })
    optimizer_def = optim.GradientDescent(0.42)
    state = train_state_lib.FlaxOptimTrainState.create(optimizer_def,
                                                       model_variables)
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "Optimizer 'GradientDescent' requires a `derive_logical_axes` method "
        'to be used with named axis partitioning.'):
      state.as_logical_axes()

  def test_to_state_dict(self):
    model_variables = flax.core.freeze({
        'params': {
            'kernel': np.zeros((2, 4))
        },
        'params_axes': {
            'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
        },
        'mutables': np.ones(3)
    })
    optimizer_def = adafactor.Adafactor(
        0.42,
        logical_factor_rules={
            'vocab': FactorDim.COLUMN,
            'embed': FactorDim.ROW
        })
    state = train_state_lib.FlaxOptimTrainState.create(optimizer_def,
                                                       model_variables)
    jax.tree_multimap(
        np.testing.assert_array_equal, state.state_dict(), {
            'state': {
                'step': np.array(0),
                'param_states': {
                    'kernel': {
                        'm': np.zeros(1),
                        'v': np.zeros((2, 4)),
                        'v_col': np.zeros(1),
                        'v_row': np.zeros(1)
                    },
                }
            },
            'target': {
                'kernel': np.zeros((2, 4))
            },
            'flax_mutables': {
                'mutables': np.ones(3)
            }
        })

  def test_restore_state(self):
    model_variables = flax.core.freeze({
        'params': {
            'kernel': np.zeros((2, 4))
        },
        'params_axes': {
            'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
        },
        'mutables': np.ones(3)
    })
    optimizer_def = adafactor.Adafactor(
        0.42,
        logical_factor_rules={
            'vocab': FactorDim.COLUMN,
            'embed': FactorDim.ROW
        })
    state = train_state_lib.FlaxOptimTrainState.create(optimizer_def,
                                                       model_variables)
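    # Restoring from a state dict overrides step, params, param states, and
    # flax_mutables, while the optimizer_def and params_axes carry over from
    # the original state.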
    restored = state.restore_state({
        'state': {
            'step': np.array(1),
            'param_states': {
                'kernel': {
                    'm': np.ones(1),
                    'v': np.ones((2, 4)),
                    'v_col': np.ones(1),
                    'v_row': np.ones(1)
                },
            }
        },
        'target': {
            'kernel': np.ones((2, 4))
        },
        'flax_mutables': {
            'mutables': np.zeros(3)
        }
    })
    self.assertEqual(restored.step, 1)
    self.assertIsInstance(restored._optimizer, optimizers.Optimizer)
    self.assertEqual(restored._optimizer.optimizer_def, optimizer_def)
    jax.tree_multimap(np.testing.assert_array_equal, restored.flax_mutables,
                      flax.core.freeze({'mutables': np.zeros(3)}))
    jax.tree_multimap(np.testing.assert_array_equal, restored.params,
                      flax.core.freeze({'kernel': np.ones((2, 4))}))
    jax.tree_multimap(
        np.testing.assert_array_equal, restored.param_states,
        flax.core.freeze({
            'kernel':
                adafactor._AdafactorParamState(
                    np.ones(1), np.ones(1), np.ones((2, 4)), np.ones(1))
        }))
    jax.tree_multimap(np.testing.assert_array_equal, restored.params_axes,
                      model_variables['params_axes'])


class InferenceStateTest(absltest.TestCase):

  def test_init(self):
    model = nn.Dense(10)
    inputs = np.ones([2, 3], dtype=np.float32)
    params = model.init(jax.random.PRNGKey(0), inputs)['params']
    flax_mutables = flax.core.freeze({'flax_mutable1': np.ones(10)})
    state = train_state_lib.InferenceState(
        step=jax.numpy.array(1), params=params, flax_mutables=flax_mutables)
    self.assertEqual(state.step, 1)
    self.assertEqual(state.flax_mutables, flax.core.unfreeze(flax_mutables))
    jax.tree_multimap(np.testing.assert_array_equal, params, state.params)
    self.assertIsNone(state.params_axes)

  def test_create(self):
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
                'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
            }
        },
        'mutables': np.ones(3)
    })
    state = train_state_lib.InferenceState.create(model_variables)
    self.assertEqual(state.step, 0)
    jax.tree_multimap(np.testing.assert_array_equal, state.flax_mutables,
                      flax.core.freeze({'mutables': np.ones(3)}))
    jax.tree_multimap(np.testing.assert_array_equal, state.params,
                      model_variables['params'])
    jax.tree_multimap(np.testing.assert_array_equal, state.params_axes,
                      model_variables['params_axes'])

  def test_create_mismatched_params_axes(self):
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
            }
        },
        'mutables': np.ones(3)
    })
    with self.assertRaisesWithLiteralMatch(
        ValueError, "Missing axis names for parameters: {'dense/kernel'}"):
      train_state_lib.InferenceState.create(model_variables)

  def test_replace_params(self):
    model_variables = flax.core.freeze({'params': {'test': np.ones(10)}})
    state = train_state_lib.InferenceState.create(model_variables)
    new_params = {'test': np.zeros(10)}
    new_state = state.replace_params(new_params)
    jax.tree_multimap(np.testing.assert_array_equal, new_params,
                      new_state.params)

  def test_replace_step(self):
    model_variables = flax.core.freeze({'params': {'test': np.ones(10)}})
    state = train_state_lib.InferenceState.create(model_variables)
    self.assertEqual(state.step, 0)
    self.assertEqual(state.replace_step(jax.numpy.array(1)).step, 1)

  def test_as_logical_axes(self):
    model_variables = flax.core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
                'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
            }
        },
    })
    state = train_state_lib.InferenceState.create(model_variables)
    axes_state = state.as_logical_axes()
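    # As with FlaxOptimTrainState, the logical-axes view clears params_axes
    # and maps each parameter to a PartitionSpec built from its axis names.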
    self.assertIsNone(axes_state.params_axes)
    jax.tree_multimap(
        np.testing.assert_array_equal, axes_state.params,
        flax.core.freeze({
            'dense': {
                'bias': partitioning.PartitionSpec('embed'),
                'kernel': partitioning.PartitionSpec('vocab', 'embed'),
            }
        }))

  def test_to_state_dict(self):
    model_variables = flax.core.freeze({
        'params': {
            'bias': np.zeros(4),
        },
        'params_axes': {
            'bias_axes': AxisMetadata(names=('embed',)),
        },
        'mutables': np.ones(3)
    })
    state = train_state_lib.InferenceState.create(model_variables)
    jax.tree_multimap(
        np.testing.assert_array_equal, state.state_dict(), {
            'state': {
                'step': np.array(0)
            },
            'target': {
                'bias': np.zeros(4),
            },
            'flax_mutables': {
                'mutables': np.ones(3)
            }
        })

  def test_to_state_dict_no_mutables(self):
    model_variables = flax.core.freeze({
        'params': {
            'bias': np.zeros(4),
        },
        'params_axes': {
            'bias_axes': AxisMetadata(names=('embed',)),
        },
    })
    state = train_state_lib.InferenceState.create(model_variables)
    jax.tree_multimap(np.testing.assert_array_equal, state.state_dict(), {
        'state': {
            'step': np.array(0)
        },
        'target': {
            'bias': np.zeros(4),
        },
    })

  def test_restore_state(self):
    state = train_state_lib.InferenceState(
        np.array(0), {'bias': np.zeros(4)},
        {'bias_axes': AxisMetadata(names=('embed',))})
    state_dict = {
        'state': {
            'step': np.array(10)
        },
        'target': {
            'bias': np.ones(4),
        },
        'flax_mutables': {
            'mutables': np.ones(3)
        }
    }
    restored = state.restore_state(state_dict)
    self.assertEqual(restored.step, 10)
    jax.tree_multimap(np.testing.assert_array_equal, restored.flax_mutables,
                      flax.core.freeze(state_dict['flax_mutables']))
    jax.tree_multimap(np.testing.assert_array_equal, restored.params,
                      flax.core.freeze(state_dict['target']))
    self.assertEqual(restored.params_axes,
                     {'bias_axes': AxisMetadata(names=('embed',))})

  def test_restore_state_no_mutables_no_axes(self):
    state = train_state_lib.InferenceState(np.array(0), {})
    state_dict = {
        'state': {
            'step': np.array(10)
        },
        'target': {
            'bias': np.zeros(4),
        },
    }
    restored = state.restore_state(state_dict)
    self.assertEqual(restored.step, 10)
    self.assertEqual(restored.flax_mutables, train_state_lib.EMPTY_DICT)
    jax.tree_multimap(np.testing.assert_array_equal, restored.params,
                      flax.core.freeze(state_dict['target']))
    self.assertIsNone(restored.params_axes)


if __name__ == '__main__':
  absltest.main()