# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for t5x.trainer_lib.""" | |
import collections | |
import contextlib | |
import os | |
from absl.testing import absltest | |
from absl.testing import parameterized | |
import chex | |
from clu import metric_writers | |
import clu.metrics | |
import clu.values | |
import flax | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
from t5x import metrics as metrics_lib | |
from t5x import models as models_lib | |
from t5x import optimizers | |
from t5x import partitioning | |
from t5x import test_utils | |
from t5x import train_state as train_state_lib | |
from t5x import trainer as trainer_lib | |
import tensorflow as tf | |
from tensorflow.io import gfile | |
mock = absltest.mock | |
jax.config.parse_flags_with_absl() | |

# Make `log_elapsed_time` a no-op to simplify mocking of `time.time()`.
# (`log_elapsed_time` is used as a context manager, so the replacement must be
# one as well; `contextlib` is imported above for exactly this purpose.)
@contextlib.contextmanager
def fake_log_elapsed_time(_):
  yield


jax._src.dispatch.log_elapsed_time = fake_log_elapsed_time
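
# `_validate_events` reads the single TF event file written under
# `summary_dir` and checks that each recorded event matches the expected step
# and the tag/value pairs in `expected_metrics`.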
def _validate_events(test_case, summary_dir, expected_metrics, steps):
  summaries = gfile.listdir(summary_dir)
  test_case.assertLen(summaries, 1)
  summary_path = os.path.join(summary_dir, summaries[0])
  event_file = os.path.join(summary_path)
  events = list(tf.compat.v1.train.summary_iterator(event_file))
  actual_events = {}
  # First event is boilerplate.
  test_case.assertLen(events, len(steps) + 1)
  for step, event in zip(steps, events[1:]):
    test_case.assertEqual(event.step, step)
    test_case.assertLen(event.summary.value, 1)
    tensor = event.summary.value[0].tensor
    if tensor.string_val:
      actual_events[event.summary.value[0].tag] = tensor.string_val[0].decode()
    else:
      actual_events[event.summary.value[0].tag] = float(tf.make_ndarray(tensor))
  jax.tree_multimap(test_case.assertAlmostEqual, actual_events,
                    expected_metrics)


class MetricsManagerTest(absltest.TestCase):

  def setUp(self):
    super().setUp()
    self.model_dir = self.create_tempdir().full_path

  def test_summary_dir(self):
    # All hosts have the summary dir.
    with mock.patch('jax.process_index', return_value=0):
      mm = trainer_lib.MetricsManager('eval', self.model_dir)
    self.assertEqual(mm.summary_dir, os.path.join(self.model_dir, 'eval'))
    mm.close()

    with mock.patch('jax.process_index', return_value=1):
      mm = trainer_lib.MetricsManager('eval', self.model_dir)
    self.assertEqual(mm.summary_dir, os.path.join(self.model_dir, 'eval'))
    mm.close()

  def test_summary_writer(self):
    # Only host 0 creates a non-empty summary writer.
    with mock.patch('jax.process_index', return_value=1):
      mm = trainer_lib.MetricsManager('eval', self.model_dir)
    self.assertFalse(gfile.exists(mm.summary_dir))
    mm.close()

    with mock.patch('jax.process_index', return_value=0):
      mm = trainer_lib.MetricsManager('eval', self.model_dir)
    self.assertIsInstance(mm.summary_writer, metric_writers.MetricWriter)
    self.assertTrue(gfile.exists(mm.summary_dir))
    mm.close()
  def test_write_scalar(self):
    gfile.makedirs(os.path.join(self.model_dir, 'eval'))
    # tag, value, step
    scalars = [('loss', 1.0, 1), ('accuracy', 100.0, 2)]

    # Only host 0 actually writes summaries.
    with mock.patch('jax.process_index', return_value=1):
      mm = trainer_lib.MetricsManager('eval', self.model_dir)
    for s in scalars:
      mm.write_scalar(*s)
    self.assertEmpty(gfile.listdir(mm.summary_dir))
    mm.close()

    with mock.patch('jax.process_index', return_value=0):
      mm = trainer_lib.MetricsManager('eval', self.model_dir)
    for s in scalars:
      mm.write_scalar(*s)
    mm.flush()

    summaries = gfile.listdir(mm.summary_dir)
    self.assertLen(summaries, 1)

    event_file = os.path.join(mm.summary_dir, summaries[0])
    events = list(tf.compat.v1.train.summary_iterator(event_file))
    # First event is boilerplate.
    self.assertLen(events, 3)
    for event, (tag, value, step) in zip(events[1:], scalars):
      self.assertEqual(event.step, step)
      self.assertLen(event.summary.value, 1)
      self.assertEqual(event.summary.value[0].tag, tag)
      self.assertEqual(tf.make_ndarray(event.summary.value[0].tensor), value)
    mm.close()
  def test_write_metrics_summary(self):
    gfile.makedirs(os.path.join(self.model_dir, 'eval'))

    class MockTextMetric(clu.metrics.Metric):

      def compute_value(self):
        return clu.values.Text('test metric')

    accumulated_metrics = {
        'loss': metrics_lib.Sum(40.0),
        'accuracy': metrics_lib.AveragePerStep.from_model_output(20.0),
        'steps_per_second': metrics_lib.StepsPerTime(),
        'text': MockTextMetric()
    }
    expected_values = {
        'loss': clu.values.Scalar(40.0),
        'accuracy': clu.values.Scalar(10.0),
        'steps_per_second': clu.values.Scalar(0.05),
        'text': clu.values.Text('test metric')
    }

    with mock.patch(
        'jax.process_index', return_value=0), mock.patch(
            'time.time',
            side_effect=[0, 40]  # start_time, end_time
        ), mock.patch('absl.logging.log'):  # avoids hidden calls to time.time()
      mm = trainer_lib.MetricsManager('eval', summary_dir=self.model_dir)
      mm.start_duration_timer()
      summary = mm.write_metrics_summary(
          accumulated_metrics, step=4, num_steps=2)
      mm.flush()

    self.assertDictEqual(summary.result(), expected_values)
    _validate_events(
        self,
        mm.summary_dir, {k: v.value for k, v in expected_values.items()},
        steps=[4, 4, 4, 4])
    mm.close()
  def test_timer_blocking_on_donated_buffer(self):
    mm = trainer_lib.MetricsManager('train', summary_dir=None)
    x = jnp.zeros(1)

    # Not deleted.
    mm.start_duration_timer(block_on=x)
    mm._duration_timer._start_future.result()

    # Deleted/donated.
    x.device_buffer.delete()
    mm.start_duration_timer(block_on=x)
    mm._duration_timer._start_future.result()

  def test_timer_concurrency(self):
    mm = trainer_lib.MetricsManager('train')
    n = 10
    with mock.patch(
        'time.time',
        side_effect=range(2 * n)  # start_time, end_time
    ), mock.patch('absl.logging.log'):  # avoids hidden calls to time.time()
      for _ in range(n):
        mm.start_duration_timer()
        summary = mm.write_metrics_summary({'time': metrics_lib.Time()}, 0, 1)
        self.assertEqual(1, summary.result()['time'].value)
    mm.flush()
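

# Fake implementations used by the Trainer tests below: they stand in for the
# real gradient-accumulation, gradient-application, and eval-step functions so
# the tests can check bookkeeping (metrics, rng handling, train state updates)
# without running a real model.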
def fake_accum_grads(model, optimizer, batch, rng, num_microbatches,
                     data_partition_spec):
  del model, num_microbatches, rng, data_partition_spec
  # Add `i` to each optimizer value.
  i = batch['i'].sum()
  grad_accum = jax.tree_map(lambda x: i, optimizer)
  # Add `j` to each metric.
  j = batch['j'].sum()
  metrics = {'loss': metrics_lib.Sum(j), 'accuracy': metrics_lib.Sum(j)}
  return grad_accum, metrics, None


def fake_apply_grads(optimizer,
                     grad_accum,
                     metrics,
                     learning_rate,
                     weight_metrics_computer,
                     other_state_variables=None):
  del weight_metrics_computer
  del other_state_variables
  metrics['learning_rate'] = clu.metrics.Average(learning_rate, count=1)
  optimizer = jax.tree_multimap(lambda x, g: x + g, optimizer, grad_accum)
  return optimizer, metrics


def fake_eval_step(model, optimizer, batch):
  del model, optimizer
  # Add `i` to each metric.
  i = batch['i'].sum()
  return {'loss': metrics_lib.Sum(i), 'accuracy': metrics_lib.Sum(i)}


def fake_eval_fn_without_weight_sum(params, batch):
  del params
  # Add `i` to each metric.
  i = batch['i'].sum()
  loss = metrics_lib.Sum(i)
  return loss, {'loss': loss, 'accuracy': metrics_lib.Sum(i)}


def fake_value_and_grad_fn_without_weight_sum(callable_fn, has_aux=False):
  del callable_fn, has_aux

  def fake_grad_fn_without_weight_sum(train_state_params,
                                      batch,
                                      dropout_rng,
                                      flax_mutables=None):
    del dropout_rng, train_state_params, flax_mutables
    # Add `i` to each optimizer value.
    i = batch['i'].sum()
    optimizer = optimizers.Optimizer(
        optimizers.sgd(0.1),
        state=optimizers.OptimizerState(
            step=0, param_states={
                'bias': 0,
                'kernel': 0
            }),
        target={
            'bias': np.zeros(4),
            'kernel': np.zeros((2, 4))
        })
    train_state = train_state_lib.FlaxOptimTrainState(optimizer)
    grad_accum = jax.tree_map(lambda x: i, train_state)
    # Add `j` to each metric.
    j = batch['j'].sum()
    metrics = {'loss': metrics_lib.Sum(j), 'accuracy': metrics_lib.Sum(j)}
    return (None, metrics), grad_accum.params

  return fake_grad_fn_without_weight_sum
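

# TrainerTest exercises trainer_lib.Trainer end to end: setUp() builds a
# Trainer around an autospec'd BaseModel, a single-partition PjitPartitioner,
# and a small tf.data pipeline of {'i', 'j'} batches.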
class TrainerTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.init_optimizer = optimizers.Optimizer(
        optimizers.sgd(0.1),
        state=optimizers.OptimizerState(
            step=0, param_states={
                'bias': 0,
                'kernel': 0
            }),
        target={
            'bias': np.zeros(4),
            'kernel': np.zeros((2, 4))
        })
    self.init_train_state = train_state_lib.FlaxOptimTrainState(
        self.init_optimizer)
    train_state_axes = jax.tree_map(lambda x: None, self.init_train_state)
    model_dir = self.create_tempdir().full_path

    mapfn = lambda i: {'i': [tf.cast(i, tf.int32)], 'j': [tf.cast(1, tf.int32)]}
    self.dataset = tf.data.Dataset.range(6).map(mapfn).batch(
        2, drop_remainder=True)

    self.test_trainer = trainer_lib.Trainer(
        mock.create_autospec(models_lib.BaseModel, instance=True),
        self.init_train_state,
        partitioning.PjitPartitioner(num_partitions=1),
        eval_names=['task1', 'task2'],
        summary_dir=model_dir,
        train_state_axes=train_state_axes,
        rng=np.ones(2, np.uint32),
        learning_rate_fn=lambda step: 2 * step,
        num_microbatches=None)

  def tearDown(self) -> None:
    self.test_trainer.close()
    return super().tearDown()
  def _test_train(self, precompile):
    trainer = self.test_trainer
    initial_rng = trainer._base_rng

    if precompile:
      with mock.patch(
          'time.time',
          side_effect=[0, 1]  # compile start, end
      ), mock.patch('absl.logging.log'):  # avoids hidden calls to time.time()
        trainer.compile_train(next(self.dataset.as_numpy_iterator()))
      trainer._compiled_train_step = mock.Mock(
          side_effect=trainer._compiled_train_step)
    trainer._partitioned_train_step = mock.Mock(
        side_effect=trainer._partitioned_train_step)

    num_steps = 2
    with mock.patch(
        'time.time',
        side_effect=[1, 5]  # start_time, end_time
    ), mock.patch('absl.logging.log'):  # avoids hidden calls to time.time()
      trainer.train(self.dataset.as_numpy_iterator(), num_steps).result()

    initial_metrics = {
        'loss': 0.,
        'accuracy': 0.,
    }
    expected_metrics = {
        k: (v + 2 * num_steps) for k, v in initial_metrics.items()
    }
    # (0 + 2) / 2 = 1
    expected_metrics['learning_rate'] = 1
    # 0+1+2+3 = 6
    expected_train_state = jax.tree_map(lambda x: np.array(x + 6),
                                        self.init_train_state)

    # Base rng must remain the same.
    np.testing.assert_array_equal(trainer._base_rng, initial_rng)
    jax.tree_multimap(np.testing.assert_equal, trainer.train_state,
                      expected_train_state)

    # Expected step is 6 since we increment it along with the other optimizer
    # values.
    steps = [2, 2, 2]
    if precompile:
      steps = [0] + steps
      expected_metrics['timing/compilation_seconds'] = 1
      self.assertEqual(trainer._compiled_train_step.call_count, num_steps)
      trainer._partitioned_train_step.assert_not_called()
    else:
      self.assertIsNone(trainer._compiled_train_step)
      self.assertEqual(trainer._partitioned_train_step.call_count, num_steps)
    trainer.train_metrics_manager.flush()
    _validate_events(
        self,
        trainer.train_metrics_manager.summary_dir,
        expected_metrics,
        steps=steps)

  def test_train_noprecompile(self):
    self._test_train(False)

  def test_train_precompile(self):
    self._test_train(True)
  def _test_eval(self, precompile):
    trainer = self.test_trainer
    initial_rng = trainer._base_rng

    task_datasets = {
        'task1': self.dataset.take(2),
        'task2': self.dataset.repeat().take(5)
    }

    if precompile:
      # [task1 start, task1 end, task2 start, task2 end]
      with mock.patch(
          'time.time',
          side_effect=[0, 1, 2, 3]  # [t1 start, t1 end, t2 start, t2 end]
      ), mock.patch('absl.logging.log'):  # avoids hidden calls to time.time()
        trainer.compile_eval({
            task: next(ds.as_numpy_iterator())
            for task, ds in task_datasets.items()
        })
      trainer._compiled_eval_steps = {
          task: mock.Mock(side_effect=trainer._compiled_eval_steps[task])
          for task in task_datasets
      }
    trainer._partitioned_eval_step = mock.Mock(
        side_effect=trainer._partitioned_eval_step)

    with mock.patch(
        'time.time',
        side_effect=[1, 5, 5, 8]  # [t1 start, t1 end, t2 start, t2 end]
    ), mock.patch('absl.logging.log'):  # avoids hidden calls to time.time()
      trainer.eval(
          {task: ds.as_numpy_iterator() for task, ds in task_datasets.items()})

    all_expected_metrics = {
        # 0+1+2+3 = 6
        'task1': {
            'loss': 6,
            'accuracy': 6,
        },
        # 0+1+2+3+4+5+0+1+2+3 = 21
        'task2': {
            'loss': 21,
            'accuracy': 21,
        },
    }

    np.testing.assert_array_equal(trainer._base_rng, initial_rng)

    for task_name, expected_metrics in all_expected_metrics.items():
      steps = [0, 0]
      if precompile:
        steps = [0] + steps
        expected_metrics['timing/compilation_seconds'] = 1
        self.assertEqual(  # pylint:disable=g-generic-assert
            trainer._compiled_eval_steps[task_name].call_count,
            len(task_datasets[task_name]))
        trainer._partitioned_eval_step.assert_not_called()
      else:
        self.assertEmpty(trainer._compiled_eval_steps)
        self.assertEqual(trainer._partitioned_eval_step.call_count,
                         sum(len(ds) for ds in task_datasets.values()))
      mm = trainer.eval_metrics_managers[task_name]
      mm.flush()
      _validate_events(self, mm.summary_dir, expected_metrics, steps=steps)

  def test_eval_noprecompile(self):
    self._test_eval(False)

  def test_eval_precompile(self):
    self._test_eval(True)
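
  # The action tests below take their arguments (mode, metric values,
  # tolerances, expected outcome) from parameterized test cases (the class
  # extends parameterized.TestCase); each one feeds metric values to the
  # action's `run` hook and checks the returned stop-training decision.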
  def test_early_stopping_action(self, mode, metrics, atol, rtol,
                                 stop_training):
    trainer = self.test_trainer
    metrics = [clu.values.Scalar(metric) for metric in metrics]
    hook = trainer_lib.EarlyStoppingAction(('test_task', 'metric'),
                                           mode=mode,
                                           patience=3,
                                           atol=atol,
                                           rtol=rtol)

    for metric in metrics:
      trainer_stop_training = hook.run(trainer.train_state,
                                       {'test_task': {
                                           'metric': metric
                                       }})

    self.assertEqual(trainer_stop_training, stop_training)

  def test_early_stopping_action_error(self, task, metric, value):
    trainer = self.test_trainer
    hook = trainer_lib.EarlyStoppingAction((task, metric),
                                           mode='min',
                                           patience=5,
                                           atol=1,
                                           rtol=1)

    trainer_stop_training = hook.run(trainer.train_state,
                                     {task: {
                                         metric: value
                                     }})

    self.assertFalse(trainer_stop_training)

  def test_terminate_on_nan_action(self, metric, value, stop_training):
    trainer = self.test_trainer
    value = clu.values.Scalar(value)
    hook = trainer_lib.TerminateOnNanAction(task='test_task', metric=metric)

    trainer_stop_training = hook.run(trainer.train_state,
                                     {'test_task': {
                                         metric: value
                                     }})

    self.assertEqual(trainer_stop_training, stop_training)

  def test_terminate_on_nan_action_error(self, task, metric, value):
    trainer = self.test_trainer
    hook = trainer_lib.TerminateOnNanAction(task=task, metric=metric)

    trainer_stop_training = hook.run(trainer.train_state,
                                     {'task': {
                                         'metric': value
                                     }})

    self.assertFalse(trainer_stop_training)
  def test_compile_train(self):
    trainer = self.test_trainer
    trainer._partitioned_train_step = mock.Mock()
    trainer.train_metrics_manager = mock.Mock()

    batch = {
        'i': np.arange(10, dtype=np.int32).reshape((2, 5)),
        'j': np.ones((), dtype=np.float32)
    }
    # compile start, compile end
    with mock.patch('time.time', side_effect=[1, 5]):
      trainer.compile_train(batch)

    trainer.train_metrics_manager.write_scalar.assert_called_with(
        'timing/compilation_seconds', 4, trainer.train_state.step)
    trainer._partitioned_train_step.lower.assert_called_once()
    train_step_args = trainer._partitioned_train_step.lower.call_args[0]
    self.assertLen(train_step_args, 2)
    self.assertEqual(train_step_args[0], trainer.train_state)
    test_utils.assert_same(train_step_args[1], batch)
  def test_compile_eval(self):
    trainer = self.test_trainer
    trainer._partitioned_eval_step = mock.Mock()
    trainer.eval_metrics_managers = {
        'eval1': mock.Mock(),
        'eval2': mock.Mock(),
        'eval3': mock.Mock(),
        'eval4': mock.Mock()
    }
    trainer._partitioned_eval_step.lower().compile.side_effect = [
        'compiled1', 'compiled2', 'compiled3'
    ]

    batches = {
        'eval1': {
            'i': np.zeros((2, 5), dtype=np.int32)
        },
        'eval2': {
            'j': np.zeros((), dtype=np.float32)
        },
        'eval3': {
            'j': np.zeros((), dtype=np.float32)
        },
        'eval4': {
            'k': np.zeros((4), dtype=np.float32)
        },
    }

    # eval1 start/end, eval2 start/end, eval3 start/end, eval4 start/end
    with mock.patch('time.time', side_effect=[1, 5, 6, 9, 10, 11, 12, 13]):
      trainer.compile_eval(collections.OrderedDict(sorted(batches.items())))

    trainer.eval_metrics_managers['eval1'].write_scalar.assert_called_with(
        'timing/compilation_seconds', 4, trainer.train_state.step)
    trainer.eval_metrics_managers['eval2'].write_scalar.assert_called_with(
        'timing/compilation_seconds', 3, trainer.train_state.step)
    trainer.eval_metrics_managers['eval3'].write_scalar.assert_called_with(
        'timing/compilation_seconds', 1, trainer.train_state.step)
    trainer.eval_metrics_managers['eval4'].write_scalar.assert_called_with(
        'timing/compilation_seconds', 1, trainer.train_state.step)

    eval_step_args = trainer._partitioned_eval_step.lower.call_args_list[1:]
    self.assertLen(eval_step_args, 3)

    eval1_call_args = eval_step_args[0][0]
    self.assertLen(eval1_call_args, 2)
    self.assertEqual(eval1_call_args[0], trainer.train_state)
    test_utils.assert_same(eval1_call_args[1], {
        'i': np.zeros((2, 5), dtype=np.int32),
    })

    eval2_call_args = eval_step_args[1][0]
    self.assertLen(eval2_call_args, 2)
    self.assertEqual(eval2_call_args[0], trainer.train_state)
    test_utils.assert_same(eval2_call_args[1], {
        'j': np.zeros((), dtype=np.float32),
    })

    eval3_call_args = eval_step_args[2][0]
    self.assertLen(eval3_call_args, 2)
    self.assertEqual(eval3_call_args[0], trainer.train_state)
    test_utils.assert_same(eval3_call_args[1], {
        'k': np.zeros((4), dtype=np.float32),
    })

    self.assertDictEqual(
        trainer._compiled_eval_steps, {
            'eval1': 'compiled1',
            'eval2': 'compiled2',
            'eval3': 'compiled2',
            'eval4': 'compiled3'
        })
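
  # The *_without_weight_sum tests below call trainer_lib's
  # accumulate_grads_microbatched and eval_step helpers directly against the
  # initial train state, rather than going through Trainer.train()/eval().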
  def test_accumulate_grads_microbatched_without_weight_sum_single_batch(self):
    batch_iter = self.dataset.as_numpy_iterator()
    batch = next(batch_iter)
    num_microbatches = 1
    grad_accum, metrics, flax_mutables = trainer_lib.accumulate_grads_microbatched(
        self.test_trainer._model, self.init_train_state, batch,
        self.test_trainer._base_rng, num_microbatches)

    i = batch['i'].sum()
    expected_grad_accum = jax.tree_map(lambda x: i,
                                       self.init_train_state).params
    self.assertEqual(expected_grad_accum, grad_accum)
    self.assertEqual(metrics['loss'].compute(), 2)
    self.assertEqual(metrics['accuracy'].compute(), 2)
    self.assertIsNone(flax_mutables)

  def test_accumulate_grads_microbatched_without_weight_sum_multiple_batches(
      self):
    batch_iter = self.dataset.as_numpy_iterator()
    batch = next(batch_iter)
    num_micro_batches = 2
    grad_accum, metrics, flax_mutables = trainer_lib.accumulate_grads_microbatched(
        self.test_trainer._model, self.init_train_state, batch,
        self.test_trainer._base_rng, num_micro_batches)

    expected_grad_accum = {'bias': jnp.ones(4), 'kernel': jnp.ones((2, 4))}
    chex.assert_trees_all_equal(expected_grad_accum, grad_accum)
    self.assertEqual(metrics['loss'].compute(), 2)
    self.assertEqual(metrics['accuracy'].compute(), 2)
    self.assertIsNone(flax_mutables)

  def test_eval_step_without_weight_sum(self):
    batch_iter = self.dataset.as_numpy_iterator()
    batch = next(batch_iter)

    self.test_trainer._model.eval_fn = fake_eval_fn_without_weight_sum
    metrics = trainer_lib.eval_step(self.test_trainer._model,
                                    self.init_train_state, batch)

    self.assertEqual(metrics['loss'].compute(), 1)
    self.assertEqual(metrics['accuracy'].compute(), 1)
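

# TrainerRngDeterminismTest checks that the per-step dropout rng depends only
# on the base rng and the step number: training from step 47 to step 100 must
# produce rngs equal to jax.random.fold_in(base_rng, step) for each step.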
class TrainerRngDeterminismTest(parameterized.TestCase):

  def create_trainer(self, step, random_seed):
    init_optimizer = optimizers.Optimizer(
        optimizers.sgd(0.1),
        state=optimizers.OptimizerState(
            step=step, param_states={
                'bias': 0,
                'kernel': 0
            }),
        target={
            'bias': np.zeros(4),
            'kernel': np.zeros((2, 4))
        })
    init_train_state = train_state_lib.FlaxOptimTrainState(init_optimizer)
    train_state_axes = jax.tree_map(lambda x: None, init_train_state)

    test_trainer = trainer_lib.Trainer(
        mock.create_autospec(models_lib.BaseModel, instance=True),
        init_train_state,
        partitioning.PjitPartitioner(num_partitions=1),
        eval_names=['task1', 'task2'],
        summary_dir=None,
        train_state_axes=train_state_axes,
        rng=jax.random.PRNGKey(random_seed),
        learning_rate_fn=lambda step: 2 * step,
        num_microbatches=None)
    return test_trainer
  # The @mock.patch decorator below is assumed from the test body (it injects
  # `mock_accum_grads`); the target is the accumulate_grads_microbatched
  # function in t5x.trainer that is exercised elsewhere in this file.
  @mock.patch('t5x.trainer.accumulate_grads_microbatched')
  def test_rng_determinism(self, mock_accum_grads):

    def fake_accum_grads_rng(model, optimizer, batch, rng, num_microbatches,
                             data_partition_spec):
      del model, batch, num_microbatches, data_partition_spec
      # Add 1, which will increment the step as a side effect.
      grad_accum = jax.tree_map(lambda x: 1, optimizer)
      m = {'rng': metrics_lib.Sum(jnp.sum(rng))}
      return grad_accum, m, None

    mock_accum_grads.side_effect = fake_accum_grads_rng
    # Create a trainer at a given step (47) with a given random seed (23),
    # train up to a given train step (100), check the sum of the rngs from the
    # metrics.
    start_step = 47
    end_step = 100
    random_seed = 23
    trainer = self.create_trainer(step=start_step, random_seed=random_seed)

    # 500 batches of size 2
    ds = [np.zeros(2)] * 500
    metrics = trainer.train(iter(ds), num_steps=end_step - start_step)
    base_rng = jax.random.PRNGKey(random_seed)
    expected_rng_sum = np.sum(
        [jax.random.fold_in(base_rng, i) for i in range(start_step, end_step)],
        dtype=np.uint32)
    np.testing.assert_array_equal(metrics.result()['rng'].value,
                                  expected_rng_sum)
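

# Fakes for MutableTrainerTest: like fake_accum_grads/fake_apply_grads above,
# except that fake_mut_accum_grads also returns a non-None flax mutables dict,
# mimicking a model with mutable (non-parameter) state.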
def fake_mut_accum_grads(model, optimizer, batch, rng, num_microbatches,
                         data_partition_spec):
  del model, num_microbatches, rng, data_partition_spec
  # Add `i` to each optimizer value.
  i = batch['i'].sum()
  grad_accum = jax.tree_map(lambda x: i, optimizer)
  # Add `j` to each metric.
  j = batch['j'].sum()
  metrics = {
      'loss': metrics_lib.Sum.from_model_output(j),
      'accuracy': metrics_lib.Sum.from_model_output(j)
  }
  return grad_accum, metrics, {'mutables': 0}


def fake_mut_apply_grads(optimizer, grad_accum, metrics, learning_rate,
                         weight_metrics_computer, other_state_variables):
  del weight_metrics_computer, other_state_variables
  metrics['learning_rate'] = clu.metrics.Average.from_model_output(
      learning_rate)
  optimizer = jax.tree_multimap(lambda x, g: x + g, optimizer, grad_accum)
  return optimizer, metrics


class MutableTrainerTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.init_optimizer = optimizers.Optimizer(
        optimizers.sgd(0.1),
        state=optimizers.OptimizerState(
            step=0, param_states={
                'bias': 0,
                'kernel': 0
            }),
        target={
            'bias': np.zeros(4),
            'kernel': np.zeros((2, 4))
        })
    self.init_train_state = train_state_lib.FlaxOptimTrainState(
        self.init_optimizer)
    train_state_axes = jax.tree_map(lambda x: None, self.init_train_state)
    model_dir = self.create_tempdir().full_path

    mapfn = lambda i: {'i': [tf.cast(i, tf.int32)], 'j': [tf.cast(1, tf.int32)]}
    self.dataset = tf.data.Dataset.range(6).map(mapfn).batch(
        2, drop_remainder=True)
    self.dataset1 = tf.data.Dataset.range(6).map(mapfn).batch(
        2, drop_remainder=True)

    self.test_trainer = trainer_lib.Trainer(
        mock.create_autospec(models_lib.BaseModel, instance=True),
        self.init_train_state,
        partitioning.PjitPartitioner(num_partitions=1),
        eval_names=['task1', 'task2'],
        summary_dir=model_dir,
        train_state_axes=train_state_axes,
        rng=np.ones(2, np.uint32),
        learning_rate_fn=lambda step: 2 * (step + 1),
        num_microbatches=None)
  # avoids calls to time.time() during logging
  # (the @mock.patch decorator below is assumed, restored so that `mock_time`
  # is injected into the test)
  @mock.patch('time.time')
  def test_train(self, mock_time=None):
    trainer = self.test_trainer
    initial_rng = trainer._base_rng
    trainer._partitioned_train_step = mock.Mock(
        side_effect=trainer._partitioned_train_step)

    # train start, logging, train end, logging
    mock_time.side_effect = [1, 5, 5, 5]
    num_steps = 1
    ds_iter = self.dataset.as_numpy_iterator()
    batch = next(ds_iter)
    train_state, _ = trainer._partitioned_train_step(trainer.train_state, batch)

    expected_train_state = jax.tree_map(lambda x: np.array(x + 1),
                                        self.init_train_state)
    # Base rng must remain the same.
    np.testing.assert_array_equal(trainer._base_rng, initial_rng)
    jax.tree_multimap(np.testing.assert_equal, train_state,
                      expected_train_state)
    self.assertIsNone(trainer._compiled_train_step)
    self.assertEqual(trainer._partitioned_train_step.call_count, num_steps)

  def tearDown(self) -> None:
    # Manually close managers to avoid phantom threads crossing test cases.
    self.test_trainer.train_metrics_manager.close()
    for mm in self.test_trainer.eval_metrics_managers.values():
      mm.close()
    return super().tearDown()


if __name__ == '__main__':
  absltest.main()