# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for trainer."""

import contextlib

from absl.testing import absltest
from flax import optim
import jax
import numpy as np
from t5x import metrics as metrics_lib
from t5x import models as models_lib
from t5x import train_state as train_state_lib
from t5x.contrib.moe import partitioning
from t5x.contrib.moe import trainer as trainer_lib
import tensorflow as tf

mock = absltest.mock
jax.config.parse_flags_with_absl()


# Make `log_elapsed_time` a no-op to simplify mocking of `time.time()`.
@contextlib.contextmanager
def fake_log_elapsed_time(_):
  yield


jax._src.dispatch.log_elapsed_time = fake_log_elapsed_time


def fake_accum_grads(model, optimizer, batch, rng, num_microbatches,
                     data_partition_spec):
  """Fake gradient accumulation returning deterministic, batch-derived values."""
  del model, num_microbatches, rng, data_partition_spec
  # Add `i` to each optimizer value.
  i = batch['i'].sum()
  grad_accum = jax.tree_map(lambda x: i, optimizer)
  # Add `j` to each metric.
  j = batch['j'].sum()
  metrics = {
      'loss': metrics_lib.Sum.from_model_output(j),
      'accuracy': metrics_lib.Sum.from_model_output(j)
  }
  return grad_accum, metrics, None


def fake_apply_grads(optimizer,
                     grad_accum,
                     metrics,
                     learning_rate,
                     weight_metrics_computer,
                     other_state_variables=None):
  """Fake gradient application that adds the accumulated grads to the optimizer."""
  del weight_metrics_computer
  del other_state_variables
  metrics['learning_rate'] = metrics_lib.Sum.from_model_output(learning_rate)
  optimizer = jax.tree_multimap(lambda x, g: x + g, optimizer, grad_accum)
  return optimizer, metrics


class MoeTrainerTest(absltest.TestCase):

  def setUp(self):
    super().setUp()
    self.init_optimizer = optim.Optimizer(
        optim.GradientDescent(),
        state=optim.OptimizerState(
            step=0, param_states={
                'expert_bias': 0,
                'kernel': 0
            }),
        target={
            'expert_bias': np.zeros(4),
            'kernel': np.zeros((2, 4))
        })
    self.init_train_state = train_state_lib.FlaxOptimTrainState(
        self.init_optimizer)
    train_state_axes = jax.tree_map(lambda x: None, self.init_train_state)
    model_dir = self.create_tempdir().full_path

    mapfn = lambda i: {'i': [tf.cast(i, tf.int32)], 'j': [tf.cast(1, tf.int32)]}
    self.dataset = tf.data.Dataset.range(6).map(mapfn).batch(
        2, drop_remainder=True)

    num_experts = 10
    self.test_trainer = trainer_lib.MoeTrainer(
        model=mock.create_autospec(models_lib.BaseModel, instance=True),
        train_state=self.init_train_state,
        partitioner=partitioning.MoePjitPartitioner(
            num_experts=num_experts, num_partitions=1),
        eval_names=['task1', 'task2'],
        summary_dir=model_dir,
        train_state_axes=train_state_axes,
        rng=np.ones(2, np.uint32),
        learning_rate_fn=lambda step: 2 * step,
        num_microbatches=None,
        num_experts=num_experts)

  @mock.patch('time.time')
  @mock.patch('t5x.trainer.accumulate_grads_microbatched', fake_accum_grads)
  @mock.patch('t5x.trainer.apply_grads', fake_apply_grads)
  @mock.patch('absl.logging.log', lambda *_: None)  # avoids time.time() calls
  def _test_train(self, precompile, mock_time=None):
    trainer = self.test_trainer
    initial_rng = trainer._base_rng

    if precompile:
      mock_time.side_effect = [0, 1]  # compile start, compile end
      trainer.compile_train(next(self.dataset.as_numpy_iterator()))
      trainer._compiled_train_step = mock.Mock(
          side_effect=trainer._compiled_train_step)

    trainer._partitioned_train_step = mock.Mock(
        side_effect=trainer._partitioned_train_step)

    # train start, train end (`absl.logging.log` is patched out above).
    mock_time.side_effect = [1, 5]
    num_steps = 2
    trainer.train(self.dataset.as_numpy_iterator(), num_steps)

    # Base rng must remain the same.
    np.testing.assert_array_equal(trainer._base_rng, initial_rng)

    expected_optimizer = optim.Optimizer(
        self.init_optimizer.optimizer_def,
        state=optim.OptimizerState(
            step=[6],
            param_states={
                'expert_bias': 60,  # 10 * (0+1+2+3) = 60
                'kernel': 6  # 0+1+2+3 = 6
            }),
        target={
            'expert_bias': 60 * np.ones(4),
            'kernel': 6 * np.ones((2, 4))
        })
    expected_train_state = train_state_lib.FlaxOptimTrainState(
        expected_optimizer)
    jax.tree_multimap(np.testing.assert_allclose, trainer.train_state,
                      expected_train_state)

    if precompile:
      self.assertEqual(trainer._compiled_train_step.call_count, num_steps)
      trainer._partitioned_train_step.assert_not_called()
    else:
      self.assertIsNone(trainer._compiled_train_step)
      self.assertEqual(trainer._partitioned_train_step.call_count, num_steps)

  def test_train_noprecompile(self):
    self._test_train(False)

  def test_train_precompile(self):
    self._test_train(True)


if __name__ == '__main__':
  absltest.main()