Spaces:
Build error
Build error
File size: 33,927 Bytes
b100e1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 |
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for t5x.trainer_lib."""
import collections
import contextlib
import os
from absl.testing import absltest
from absl.testing import parameterized
import chex
from clu import metric_writers
import clu.metrics
import clu.values
import flax
import jax
import jax.numpy as jnp
import numpy as np
from t5x import metrics as metrics_lib
from t5x import models as models_lib
from t5x import optimizers
from t5x import partitioning
from t5x import test_utils
from t5x import train_state as train_state_lib
from t5x import trainer as trainer_lib
import tensorflow as tf
from tensorflow.io import gfile
# Shorthand for the `mock` module re-exported by absltest.
mock = absltest.mock
# Let JAX pick up its configuration flags through absl flag parsing.
jax.config.parse_flags_with_absl()
# Swap JAX's `log_elapsed_time` for a context manager that does nothing, so
# that tests which mock `time.time()` are not affected by it.
@contextlib.contextmanager
def fake_log_elapsed_time(_):
  """No-op context manager; its single argument is ignored."""
  yield None


jax._src.dispatch.log_elapsed_time = fake_log_elapsed_time
def _validate_events(test_case, summary_dir, expected_metrics, steps):
  """Checks the single event file in `summary_dir` against expectations.

  Args:
    test_case: test case used to issue the assertions.
    summary_dir: directory that must contain exactly one event file.
    expected_metrics: mapping from summary tag to its expected value (a string
      for text summaries, a number otherwise).
    steps: expected step of each non-boilerplate event, in file order.
  """
  summaries = gfile.listdir(summary_dir)
  test_case.assertLen(summaries, 1)
  event_file = os.path.join(summary_dir, summaries[0])
  events = list(tf.compat.v1.train.summary_iterator(event_file))
  # First event is boilerplate; every following event holds one summary value.
  test_case.assertLen(events, len(steps) + 1)

  actual_events = {}
  for step, event in zip(steps, events[1:]):
    test_case.assertEqual(event.step, step)
    test_case.assertLen(event.summary.value, 1)
    summary_value = event.summary.value[0]
    tensor = summary_value.tensor
    if tensor.string_val:
      actual_events[summary_value.tag] = tensor.string_val[0].decode()
    else:
      actual_events[summary_value.tag] = float(tf.make_ndarray(tensor))

  # Compare the tag sets first so a missing/extra tag produces a readable
  # failure rather than a pytree-structure mismatch error.
  test_case.assertCountEqual(actual_events, expected_metrics)
  # `jax.tree_multimap` has been removed from JAX; `tree_map` accepts
  # multiple trees and pairs leaves by dict key.
  jax.tree_util.tree_map(test_case.assertAlmostEqual, actual_events,
                         expected_metrics)
class MetricsManagerTest(absltest.TestCase):
  """Tests for `trainer_lib.MetricsManager`."""

  def setUp(self):
    super().setUp()
    self.model_dir = self.create_tempdir().full_path

  def test_summary_dir(self):
    # All hosts have the summary dir.
    with mock.patch('jax.process_index', return_value=0):
      mm = trainer_lib.MetricsManager('eval', self.model_dir)
      self.assertEqual(mm.summary_dir, os.path.join(self.model_dir, 'eval'))
      mm.close()

    with mock.patch('jax.process_index', return_value=1):
      mm = trainer_lib.MetricsManager('eval', self.model_dir)
      self.assertEqual(mm.summary_dir, os.path.join(self.model_dir, 'eval'))
      mm.close()

  def test_summary_writer(self):
    # Only host 0 creates a non-empty summary writer.
    with mock.patch('jax.process_index', return_value=1):
      mm = trainer_lib.MetricsManager('eval', self.model_dir)
      self.assertFalse(gfile.exists(mm.summary_dir))
      mm.close()

    with mock.patch('jax.process_index', return_value=0):
      mm = trainer_lib.MetricsManager('eval', self.model_dir)
      self.assertIsInstance(mm.summary_writer, metric_writers.MetricWriter)
      self.assertTrue(gfile.exists(mm.summary_dir))
      mm.close()

  def test_write_scalar(self):
    gfile.makedirs(os.path.join(self.model_dir, 'eval'))

    # tag, value, step
    scalars = [('loss', 1.0, 1), ('accuracy', 100.0, 2)]

    # Only host 0 actually writes summaries.
    with mock.patch('jax.process_index', return_value=1):
      mm = trainer_lib.MetricsManager('eval', self.model_dir)
      for s in scalars:
        mm.write_scalar(*s)
      self.assertEmpty(gfile.listdir(mm.summary_dir))
      mm.close()

    with mock.patch('jax.process_index', return_value=0):
      mm = trainer_lib.MetricsManager('eval', self.model_dir)
      for s in scalars:
        mm.write_scalar(*s)
      mm.flush()

      summaries = gfile.listdir(mm.summary_dir)
      self.assertLen(summaries, 1)

      event_file = os.path.join(mm.summary_dir, summaries[0])
      events = list(tf.compat.v1.train.summary_iterator(event_file))
      # First event is boilerplate
      self.assertLen(events, 3)
      for event, (tag, value, step) in zip(events[1:], scalars):
        self.assertEqual(event.step, step)
        self.assertLen(event.summary.value, 1)
        self.assertEqual(event.summary.value[0].tag, tag)
        self.assertEqual(tf.make_ndarray(event.summary.value[0].tensor), value)
      mm.close()

  def test_write_metrics_summary(self):
    gfile.makedirs(os.path.join(self.model_dir, 'eval'))

    @flax.struct.dataclass
    class MockTextMetric(clu.metrics.Metric):

      def compute_value(self):
        return clu.values.Text('test metric')

    accumulated_metrics = {
        'loss': metrics_lib.Sum(40.0),
        'accuracy': metrics_lib.AveragePerStep.from_model_output(20.0),
        'steps_per_second': metrics_lib.StepsPerTime(),
        'text': MockTextMetric()
    }
    # accuracy: 20 / num_steps(2) = 10; steps_per_second: 2 steps / 40s.
    expected_values = {
        'loss': clu.values.Scalar(40.0),
        'accuracy': clu.values.Scalar(10.0),
        'steps_per_second': clu.values.Scalar(0.05),
        'text': clu.values.Text('test metric')
    }
    with mock.patch(
        'jax.process_index', return_value=0), mock.patch(
            'time.time',
            side_effect=[0, 40]  # start_time, end_time
        ), mock.patch('absl.logging.log'):  # avoids hidden calls to time.time()
      mm = trainer_lib.MetricsManager('eval', summary_dir=self.model_dir)
      mm.start_duration_timer()
      summary = mm.write_metrics_summary(
          accumulated_metrics, step=4, num_steps=2)
      mm.flush()

    self.assertDictEqual(summary.result(), expected_values)
    _validate_events(
        self,
        mm.summary_dir, {k: v.value for k, v in expected_values.items()},
        steps=[4, 4, 4, 4])

    mm.close()

  def test_timer_blocking_on_donated_buffer(self):
    mm = trainer_lib.MetricsManager('train', summary_dir=None)
    x = jnp.zeros(1)

    # Not deleted.
    mm.start_duration_timer(block_on=x)
    mm._duration_timer._start_future.result()

    # Deleted/donated.
    x.device_buffer.delete()
    mm.start_duration_timer(block_on=x)
    mm._duration_timer._start_future.result()

    # Release the manager's background resources like the other tests do.
    mm.close()

  def test_timer_concurrency(self):
    mm = trainer_lib.MetricsManager('train')

    n = 10
    with mock.patch(
        'time.time',
        side_effect=range(2 * n)  # start_time, end_time
    ), mock.patch('absl.logging.log'):  # avoids hidden calls to time.time()
      for _ in range(n):
        mm.start_duration_timer()
        summary = mm.write_metrics_summary({'time': metrics_lib.Time()}, 0, 1)
        self.assertEqual(1, summary.result()['time'].value)
      mm.flush()

    # Close outside the `time.time` patch, whose side_effect is exhausted.
    mm.close()
def fake_accum_grads(model, optimizer, batch, rng, num_microbatches,
                     data_partition_spec):
  """Fake `accumulate_grads_microbatched` used to patch the trainer.

  Every leaf of the returned gradient tree equals the sum of `batch['i']`
  (so `i` is added to each optimizer value when the gradients are applied by
  adding them), and both metrics are `metrics_lib.Sum` of the sum of
  `batch['j']`.

  Returns:
    A `(grad_accum, metrics, flax_mutables)` triple with no flax mutables.
  """
  del model, num_microbatches, rng, data_partition_spec
  i_total = batch['i'].sum()
  j_total = batch['j'].sum()
  grad_accum = jax.tree_util.tree_map(lambda _: i_total, optimizer)
  metrics = {
      'loss': metrics_lib.Sum(j_total),
      'accuracy': metrics_lib.Sum(j_total),
  }
  return grad_accum, metrics, None
def fake_apply_grads(optimizer,
                     grad_accum,
                     metrics,
                     learning_rate,
                     weight_metrics_computer,
                     other_state_variables=None):
  """Fake `apply_grads`: adds `grad_accum` leaf-wise to `optimizer`.

  Also records `learning_rate` as an `Average` metric. Note that `metrics` is
  mutated in place and returned.

  Returns:
    `(updated_optimizer, metrics)`.
  """
  del weight_metrics_computer
  del other_state_variables
  metrics['learning_rate'] = clu.metrics.Average(learning_rate, count=1)
  # `jax.tree_multimap` was removed from JAX; `tree_map` takes extra trees.
  optimizer = jax.tree_util.tree_map(lambda x, g: x + g, optimizer, grad_accum)
  return optimizer, metrics
def fake_eval_step(model, optimizer, batch):
  """Fake `eval_step`: both metrics are `Sum` of the sum of `batch['i']`."""
  del model, optimizer
  i_total = batch['i'].sum()
  return {name: metrics_lib.Sum(i_total) for name in ('loss', 'accuracy')}
def fake_eval_fn_without_weight_sum(params, batch):
  """Fake model `eval_fn` returning `(loss, metrics)` with no weight sum.

  The loss and the accuracy are both `metrics_lib.Sum` of the sum of
  `batch['i']`; the returned loss object is the same one stored in `metrics`.
  """
  del params
  i_total = batch['i'].sum()
  loss = metrics_lib.Sum(i_total)
  metrics = {'loss': loss, 'accuracy': metrics_lib.Sum(i_total)}
  return loss, metrics
def fake_value_and_grad_fn_without_weight_sum(callable_fn, has_aux=False):
  """Fake `jax.value_and_grad` that ignores its arguments.

  Returns:
    A function `(train_state_params, batch, dropout_rng, flax_mutables=None)`
    that yields `((None, metrics), grads)`. Every gradient leaf equals the sum
    of `batch['i']`, and both metrics are `metrics_lib.Sum` of the sum of
    `batch['j']`.
  """
  del callable_fn, has_aux

  def fake_grad_fn_without_weight_sum(train_state_params,
                                      batch,
                                      dropout_rng,
                                      flax_mutables=None):
    del dropout_rng, train_state_params, flax_mutables
    i_total = batch['i'].sum()
    j_total = batch['j'].sum()
    # Zero-valued train state with the same structure the tests build; only
    # its tree structure is used to shape the fake gradients.
    template_optimizer = optimizers.Optimizer(
        optimizers.sgd(0.1),
        state=optimizers.OptimizerState(
            step=0, param_states={'bias': 0, 'kernel': 0}),
        target={'bias': np.zeros(4), 'kernel': np.zeros((2, 4))})
    template_state = train_state_lib.FlaxOptimTrainState(template_optimizer)
    grads = jax.tree_util.tree_map(lambda _: i_total, template_state).params
    metrics = {
        'loss': metrics_lib.Sum(j_total),
        'accuracy': metrics_lib.Sum(j_total),
    }
    return (None, metrics), grads

  return fake_grad_fn_without_weight_sum
class TrainerTest(parameterized.TestCase):
  """Tests for `trainer_lib.Trainer` and its early-stopping / NaN actions."""

  def setUp(self):
    super().setUp()
    self.init_optimizer = optimizers.Optimizer(
        optimizers.sgd(0.1),
        state=optimizers.OptimizerState(
            step=0, param_states={
                'bias': 0,
                'kernel': 0
            }),
        target={
            'bias': np.zeros(4),
            'kernel': np.zeros((2, 4))
        })
    self.init_train_state = train_state_lib.FlaxOptimTrainState(
        self.init_optimizer)
    train_state_axes = jax.tree_util.tree_map(lambda x: None,
                                              self.init_train_state)
    model_dir = self.create_tempdir().full_path

    # Three batches of two examples: `i` is (0, 1), (2, 3), (4, 5) and `j` is
    # always 1.
    mapfn = lambda i: {'i': [tf.cast(i, tf.int32)], 'j': [tf.cast(1, tf.int32)]}
    self.dataset = tf.data.Dataset.range(6).map(mapfn).batch(
        2, drop_remainder=True)
    self.test_trainer = trainer_lib.Trainer(
        mock.create_autospec(models_lib.BaseModel, instance=True),
        self.init_train_state,
        partitioning.PjitPartitioner(num_partitions=1),
        eval_names=['task1', 'task2'],
        summary_dir=model_dir,
        train_state_axes=train_state_axes,
        rng=np.ones(2, np.uint32),
        learning_rate_fn=lambda step: 2 * step,
        num_microbatches=None)

  def tearDown(self) -> None:
    self.test_trainer.close()
    return super().tearDown()

  @mock.patch('t5x.trainer.accumulate_grads_microbatched', fake_accum_grads)
  @mock.patch('t5x.trainer.apply_grads', fake_apply_grads)
  def _test_train(self, precompile):
    """Trains two steps and checks train state, call counts and summaries."""
    trainer = self.test_trainer
    initial_rng = trainer._base_rng

    if precompile:
      with mock.patch(
          'time.time',
          side_effect=[0, 1]  # compile start, end
      ), mock.patch('absl.logging.log'):  # avoids hidden calls to time.time()
        trainer.compile_train(next(self.dataset.as_numpy_iterator()))
      trainer._compiled_train_step = mock.Mock(
          side_effect=trainer._compiled_train_step)

    trainer._partitioned_train_step = mock.Mock(
        side_effect=trainer._partitioned_train_step)

    num_steps = 2
    with mock.patch(
        'time.time',
        side_effect=[1, 5]  # start_time, end_time
    ), mock.patch('absl.logging.log'):  # avoids hidden calls to time.time()
      trainer.train(self.dataset.as_numpy_iterator(), num_steps).result()

    initial_metrics = {
        'loss': 0.,
        'accuracy': 0.,
    }
    # Each batch contributes j=1 for each of its two examples.
    expected_metrics = {
        k: (v + 2 * num_steps) for k, v in initial_metrics.items()
    }
    # (0 + 2) / 2 = 1
    expected_metrics['learning_rate'] = 1
    # Every optimizer value -- including the step, which is incremented along
    # with the other optimizer values -- grows by 0+1+2+3 = 6.
    expected_train_state = jax.tree_util.tree_map(lambda x: np.array(x + 6),
                                                  self.init_train_state)
    # Base rng must remain the same
    np.testing.assert_array_equal(trainer._base_rng, initial_rng)
    jax.tree_util.tree_map(np.testing.assert_equal, trainer.train_state,
                           expected_train_state)

    # Each of the three train metrics is written once, at step 2 (num_steps).
    steps = [2, 2, 2]
    if precompile:
      steps = [0] + steps
      expected_metrics['timing/compilation_seconds'] = 1
      self.assertEqual(trainer._compiled_train_step.call_count, num_steps)
      trainer._partitioned_train_step.assert_not_called()
    else:
      self.assertIsNone(trainer._compiled_train_step)
      self.assertEqual(trainer._partitioned_train_step.call_count, num_steps)

    trainer.train_metrics_manager.flush()
    _validate_events(
        self,
        trainer.train_metrics_manager.summary_dir,
        expected_metrics,
        steps=steps)

  def test_train_noprecompile(self):
    self._test_train(False)

  def test_train_precompile(self):
    self._test_train(True)

  @mock.patch('t5x.trainer.eval_step', fake_eval_step)
  def _test_eval(self, precompile):
    """Evaluates two tasks and checks call counts and per-task summaries."""
    trainer = self.test_trainer
    initial_rng = trainer._base_rng

    task_datasets = {
        'task1': self.dataset.take(2),
        'task2': self.dataset.repeat().take(5)
    }

    if precompile:
      with mock.patch(
          'time.time',
          side_effect=[0, 1, 2, 3]  # [t1 start, t1 end, t2 start, t2 end]
      ), mock.patch('absl.logging.log'):  # avoids hidden calls to time.time()
        trainer.compile_eval({
            task: next(ds.as_numpy_iterator())
            for task, ds in task_datasets.items()
        })
      trainer._compiled_eval_steps = {
          task: mock.Mock(side_effect=trainer._compiled_eval_steps[task])
          for task in task_datasets
      }

    trainer._partitioned_eval_step = mock.Mock(
        side_effect=trainer._partitioned_eval_step)

    with mock.patch(
        'time.time',
        side_effect=[1, 5, 5, 8]  # [t1 start, t1 end, t2 start, t2 end]
    ), mock.patch('absl.logging.log'):  # avoids hidden calls to time.time()
      trainer.eval(
          {task: ds.as_numpy_iterator() for task, ds in task_datasets.items()})

    all_expected_metrics = {
        # 0+1+2+3 = 6
        'task1': {
            'loss': 6,
            'accuracy': 6,
        },
        # 0+1+2+3+4+5+0+1+2+3 = 21
        'task2': {
            'loss': 21,
            'accuracy': 21,
        },
    }

    np.testing.assert_array_equal(trainer._base_rng, initial_rng)
    for task_name, expected_metrics in all_expected_metrics.items():
      steps = [0, 0]
      if precompile:
        steps = [0] + steps
        expected_metrics['timing/compilation_seconds'] = 1
        self.assertEqual(  # pylint:disable=g-generic-assert
            trainer._compiled_eval_steps[task_name].call_count,
            len(task_datasets[task_name]))
        trainer._partitioned_eval_step.assert_not_called()
      else:
        self.assertEmpty(trainer._compiled_eval_steps)
        self.assertEqual(trainer._partitioned_eval_step.call_count,
                         sum(len(ds) for ds in task_datasets.values()))
      mm = trainer.eval_metrics_managers[task_name]
      mm.flush()
      _validate_events(self, mm.summary_dir, expected_metrics, steps=steps)

  def test_eval_noprecompile(self):
    self._test_eval(False)

  def test_eval_precompile(self):
    self._test_eval(True)

  @parameterized.named_parameters([
      {
          'testcase_name': 'max_no_increase',
          'mode': 'max',
          'metrics': [1, 1, 1],
          'atol': 0.,
          'rtol': 0.,
          'stop_training': True,
      },
      {
          'testcase_name': 'max_no_atol',
          'mode': 'max',
          'metrics': [1, 0.9, 0.8],
          'atol': 0.,
          'rtol': 0.,
          'stop_training': True,
      },
      {
          'testcase_name': 'max_not_enough_atol',
          'mode': 'max',
          'metrics': [1, 1.09, 1.18],
          'atol': 0.1,
          'rtol': 0.,
          'stop_training': True,
      },
      {
          'testcase_name': 'max_enough_atol',
          'mode': 'max',
          'metrics': [1, 1.2, 1.4],
          'atol': 0.1,
          'rtol': 0.,
          'stop_training': False,
      },
      {
          'testcase_name': 'max_enough_atol_rtol',
          'mode': 'max',
          # first delta = 0.1 + 1* 0.08 = 0.18
          # second delta = 0.1 + 1.2 * 0.08 = 0.196
          'metrics': [1, 1.2, 1.4],
          'atol': 0.1,
          'rtol': 0.08,
          'stop_training': False,
      },
      {
          'testcase_name': 'max_not_enough_rtol',
          'mode': 'max',
          'metrics': [1, 1.2, 1.4],
          'atol': 0.,
          'rtol': 0.2,
          'stop_training': True,
      },
      {
          'testcase_name': 'min_no_decrease',
          'mode': 'min',
          'metrics': [1, 1, 1],
          'atol': 0.,
          'rtol': 0.,
          'stop_training': True,
      },
      {
          # Mirror of `max_no_atol`: the metric moves in the wrong direction.
          # (Previously an exact duplicate of `min_no_decrease`.)
          'testcase_name': 'min_no_atol',
          'mode': 'min',
          'metrics': [1, 1.1, 1.2],
          'atol': 0.,
          'rtol': 0.,
          'stop_training': True,
      },
      {
          'testcase_name': 'min_not_enough_atol',
          'mode': 'min',
          'metrics': [1, 0.9, 0.71],
          'atol': 0.2,
          'rtol': 0.,
          'stop_training': True,
      },
      {
          'testcase_name': 'min_enough_atol',
          'mode': 'min',
          'metrics': [1, 0.8, 0.6],
          'atol': 0.15,
          'rtol': 0.,
          'stop_training': False,
      },
      {
          'testcase_name': 'min_enough_atol_rtol',
          'mode': 'min',
          # first delta = 0.1 + 1* 0.09 = 0.19
          # second delta = 0.1 + 0.8 * 0.09 = 0.172
          'metrics': [1, 0.8, 0.6],
          'atol': 0.1,
          'rtol': 0.09,
          'stop_training': False,
      },
      {
          'testcase_name': 'min_not_enough_rtol',
          'mode': 'min',
          'metrics': [1, 0.8, 0.6],
          'atol': 0.0,
          'rtol': 0.3,
          'stop_training': True,
      },
      {
          'testcase_name': 'longer_history',
          'mode': 'min',
          'metrics': [1, 0.8, 0.7, 0.6],
          'atol': 0.15,
          'rtol': 0.,
          'stop_training': True,
      }
  ])
  def test_early_stopping_action(self, mode, metrics, atol, rtol,
                                 stop_training):
    """Feeds `metrics` one by one and checks the verdict after the last."""
    trainer = self.test_trainer
    metrics = [clu.values.Scalar(metric) for metric in metrics]
    hook = trainer_lib.EarlyStoppingAction(('test_task', 'metric'),
                                           mode=mode,
                                           patience=3,
                                           atol=atol,
                                           rtol=rtol)

    for metric in metrics:
      trainer_stop_training = hook.run(trainer.train_state,
                                       {'test_task': {
                                           'metric': metric
                                       }})

    self.assertEqual(trainer_stop_training, stop_training)

  @parameterized.named_parameters([
      {
          'testcase_name': 'invalid_task',
          'task': 'wrong_task',
          'metric': 'metric',
          'value': clu.values.Scalar(np.nan),
      },
      {
          'testcase_name': 'invalid_metric_name',
          'task': 'task',
          'metric': 'wrong_metric_name',
          'value': clu.values.Scalar(np.nan),
      },
      {
          'testcase_name': 'invalid_value',
          'task': 'task',
          'metric': 'metric',
          'value': 1.0,
      },
  ])
  def test_early_stopping_action_error(self, task, metric, value):
    """The hook must not stop training when it cannot read a valid metric."""
    trainer = self.test_trainer
    hook = trainer_lib.EarlyStoppingAction((task, metric),
                                           mode='min',
                                           patience=5,
                                           atol=1,
                                           rtol=1)

    # The metrics always live under 'task'/'metric', so a hook configured with
    # a different task or metric name really misses them (previously the dict
    # was keyed by the hook's own names, so those cases tested nothing).
    trainer_stop_training = hook.run(trainer.train_state,
                                     {'task': {
                                         'metric': value
                                     }})

    self.assertFalse(trainer_stop_training)

  @parameterized.named_parameters([{
      'testcase_name': 'valid_loss',
      'metric': 'loss',
      'value': 1.0,
      'stop_training': False,
  }, {
      'testcase_name': 'nan',
      'metric': 'loss',
      'value': np.nan,
      'stop_training': True,
  }, {
      'testcase_name': 'inf',
      'metric': 'loss',
      'value': np.inf,
      'stop_training': True,
  }, {
      'testcase_name': 'other_metric',
      'metric': 'some_metric',
      'value': np.inf,
      'stop_training': True,
  }])
  def test_terminate_on_nan_action(self, metric, value, stop_training):
    trainer = self.test_trainer
    value = clu.values.Scalar(value)
    hook = trainer_lib.TerminateOnNanAction(task='test_task', metric=metric)

    trainer_stop_training = hook.run(trainer.train_state,
                                     {'test_task': {
                                         metric: value
                                     }})

    self.assertEqual(trainer_stop_training, stop_training)

  @parameterized.named_parameters([
      {
          'testcase_name': 'invalid_task',
          'task': 'wrong_task',
          'metric': 'metric',
          'value': clu.values.Scalar(np.nan),
      },
      {
          'testcase_name': 'invalid_metric_name',
          'task': 'task',
          'metric': 'wrong_metric_name',
          'value': clu.values.Scalar(np.nan),
      },
      {
          'testcase_name': 'invalid_value',
          'task': 'task',
          'metric': 'metric',
          'value': 1.0,
      },
  ])
  def test_terminate_on_nan_action_error(self, task, metric, value):
    """The hook must not stop training when it cannot read a valid metric."""
    trainer = self.test_trainer
    hook = trainer_lib.TerminateOnNanAction(task=task, metric=metric)

    # The metrics always live under 'task'/'metric'.
    trainer_stop_training = hook.run(trainer.train_state,
                                     {'task': {
                                         'metric': value
                                     }})

    self.assertFalse(trainer_stop_training)

  def test_compile_train(self):
    trainer = self.test_trainer
    trainer._partitioned_train_step = mock.Mock()
    trainer.train_metrics_manager = mock.Mock()

    batch = {
        'i': np.arange(10, dtype=np.int32).reshape((2, 5)),
        'j': np.ones((), dtype=np.float32)
    }
    # compile start, compile end
    with mock.patch('time.time', side_effect=[1, 5]):
      trainer.compile_train(batch)

    trainer.train_metrics_manager.write_scalar.assert_called_with(
        'timing/compilation_seconds', 4, trainer.train_state.step)
    trainer._partitioned_train_step.lower.assert_called_once()
    train_step_args = trainer._partitioned_train_step.lower.call_args[0]
    self.assertLen(train_step_args, 2)
    self.assertEqual(train_step_args[0], trainer.train_state)
    test_utils.assert_same(train_step_args[1], batch)

  def test_compile_eval(self):
    trainer = self.test_trainer
    trainer._partitioned_eval_step = mock.Mock()
    trainer.eval_metrics_managers = {
        'eval1': mock.Mock(),
        'eval2': mock.Mock(),
        'eval3': mock.Mock(),
        'eval4': mock.Mock()
    }
    trainer._partitioned_eval_step.lower().compile.side_effect = [
        'compiled1', 'compiled2', 'compiled3'
    ]

    # eval2 and eval3 have identical batches, so eval3 reuses eval2's
    # compiled step and only three compilations happen.
    batches = {
        'eval1': {
            'i': np.zeros((2, 5), dtype=np.int32)
        },
        'eval2': {
            'j': np.zeros((), dtype=np.float32)
        },
        'eval3': {
            'j': np.zeros((), dtype=np.float32)
        },
        'eval4': {
            'k': np.zeros(4, dtype=np.float32)
        },
    }

    # eval1 start/end, eval2 start/end, eval3 start/end, eval 4 start/end
    with mock.patch('time.time', side_effect=[1, 5, 6, 9, 10, 11, 12, 13]):
      trainer.compile_eval(collections.OrderedDict(sorted(batches.items())))

    trainer.eval_metrics_managers['eval1'].write_scalar.assert_called_with(
        'timing/compilation_seconds', 4, trainer.train_state.step)
    trainer.eval_metrics_managers['eval2'].write_scalar.assert_called_with(
        'timing/compilation_seconds', 3, trainer.train_state.step)
    trainer.eval_metrics_managers['eval3'].write_scalar.assert_called_with(
        'timing/compilation_seconds', 1, trainer.train_state.step)
    trainer.eval_metrics_managers['eval4'].write_scalar.assert_called_with(
        'timing/compilation_seconds', 1, trainer.train_state.step)

    # Skip the first recorded `lower` call: it is the bare `lower()` made
    # above to configure `compile.side_effect`.
    eval_step_args = trainer._partitioned_eval_step.lower.call_args_list[1:]
    self.assertLen(eval_step_args, 3)

    eval1_call_args = eval_step_args[0][0]
    self.assertLen(eval1_call_args, 2)
    self.assertEqual(eval1_call_args[0], trainer.train_state)
    test_utils.assert_same(eval1_call_args[1], {
        'i': np.zeros((2, 5), dtype=np.int32),
    })

    eval2_call_args = eval_step_args[1][0]
    self.assertLen(eval2_call_args, 2)
    self.assertEqual(eval2_call_args[0], trainer.train_state)
    test_utils.assert_same(eval2_call_args[1], {
        'j': np.zeros((), dtype=np.float32),
    })

    # The third lowering belongs to eval4 (eval3 was served from the cache).
    eval4_call_args = eval_step_args[2][0]
    self.assertLen(eval4_call_args, 2)
    self.assertEqual(eval4_call_args[0], trainer.train_state)
    test_utils.assert_same(eval4_call_args[1], {
        'k': np.zeros(4, dtype=np.float32),
    })

    self.assertDictEqual(
        trainer._compiled_eval_steps, {
            'eval1': 'compiled1',
            'eval2': 'compiled2',
            'eval3': 'compiled2',
            'eval4': 'compiled3'
        })

  @mock.patch('jax.value_and_grad', fake_value_and_grad_fn_without_weight_sum)
  def test_accumulate_grads_microbatched_without_weight_sum_single_batch(self):
    batch_iter = self.dataset.as_numpy_iterator()
    batch = next(batch_iter)
    num_microbatches = 1
    grad_accum, metrics, flax_mutables = (
        trainer_lib.accumulate_grads_microbatched(
            self.test_trainer._model, self.init_train_state, batch,
            self.test_trainer._base_rng, num_microbatches))

    i = batch['i'].sum()
    expected_grad_accum = jax.tree_util.tree_map(lambda x: i,
                                                 self.init_train_state).params
    self.assertEqual(expected_grad_accum, grad_accum)
    self.assertEqual(metrics['loss'].compute(), 2)
    self.assertEqual(metrics['accuracy'].compute(), 2)
    self.assertIsNone(flax_mutables)

  @mock.patch('jax.value_and_grad', fake_value_and_grad_fn_without_weight_sum)
  def test_accumulate_grads_microbatched_without_weight_sum_multiple_batches(
      self):
    batch_iter = self.dataset.as_numpy_iterator()
    batch = next(batch_iter)
    num_micro_batches = 2
    grad_accum, metrics, flax_mutables = (
        trainer_lib.accumulate_grads_microbatched(
            self.test_trainer._model, self.init_train_state, batch,
            self.test_trainer._base_rng, num_micro_batches))

    expected_grad_accum = {'bias': jnp.ones(4), 'kernel': jnp.ones((2, 4))}
    chex.assert_trees_all_equal(expected_grad_accum, grad_accum)
    self.assertEqual(metrics['loss'].compute(), 2)
    self.assertEqual(metrics['accuracy'].compute(), 2)
    self.assertIsNone(flax_mutables)

  def test_eval_step_without_weight_sum(self):
    batch_iter = self.dataset.as_numpy_iterator()
    batch = next(batch_iter)
    self.test_trainer._model.eval_fn = fake_eval_fn_without_weight_sum
    metrics = trainer_lib.eval_step(self.test_trainer._model,
                                    self.init_train_state, batch)

    self.assertEqual(metrics['loss'].compute(), 1)
    self.assertEqual(metrics['accuracy'].compute(), 1)
class TrainerRngDeterminismTest(parameterized.TestCase):
  """Checks that per-step RNGs are derived from the base RNG and the step."""

  def create_trainer(self, step, random_seed):
    """Builds a `Trainer` whose optimizer starts at `step`.

    The trainer's base RNG is `jax.random.PRNGKey(random_seed)`. The trainer
    is closed automatically when the test finishes.
    """
    init_optimizer = optimizers.Optimizer(
        optimizers.sgd(0.1),
        state=optimizers.OptimizerState(
            step=step, param_states={
                'bias': 0,
                'kernel': 0
            }),
        target={
            'bias': np.zeros(4),
            'kernel': np.zeros((2, 4))
        })
    init_train_state = train_state_lib.FlaxOptimTrainState(init_optimizer)
    train_state_axes = jax.tree_util.tree_map(lambda x: None, init_train_state)
    test_trainer = trainer_lib.Trainer(
        mock.create_autospec(models_lib.BaseModel, instance=True),
        init_train_state,
        partitioning.PjitPartitioner(num_partitions=1),
        eval_names=['task1', 'task2'],
        summary_dir=None,
        train_state_axes=train_state_axes,
        rng=jax.random.PRNGKey(random_seed),
        learning_rate_fn=lambda step: 2 * step,
        num_microbatches=None)
    # Close the metrics managers so their threads do not outlive the test.
    self.addCleanup(test_trainer.close)
    return test_trainer

  @mock.patch('t5x.trainer.accumulate_grads_microbatched')
  @mock.patch('t5x.trainer.apply_grads', fake_apply_grads)
  def test_rng_determinism(self, mock_accum_grads):

    def fake_accum_grads_rng(model, optimizer, batch, rng, num_microbatches,
                             data_partition_spec):
      del model, batch, num_microbatches, data_partition_spec
      # Add 1, which will increment the step as a side effect.
      grad_accum = jax.tree_util.tree_map(lambda x: 1, optimizer)
      m = {'rng': metrics_lib.Sum(jnp.sum(rng))}
      return grad_accum, m, None

    mock_accum_grads.side_effect = fake_accum_grads_rng
    # Create a trainer at a given step (47) with a given random seed (23),
    # train up to a given train step (100), check the sum of the rngs from the
    # metrics.
    start_step = 47
    end_step = 100
    random_seed = 23
    trainer = self.create_trainer(step=start_step, random_seed=random_seed)
    # 500 batches of size 2
    ds = [np.zeros(2)] * 500

    metrics = trainer.train(iter(ds), num_steps=end_step - start_step)
    base_rng = jax.random.PRNGKey(random_seed)
    expected_rng_sum = np.sum(
        [jax.random.fold_in(base_rng, i) for i in range(start_step, end_step)],
        dtype=np.uint32)
    np.testing.assert_array_equal(metrics.result()['rng'].value,
                                  expected_rng_sum)
def fake_mut_accum_grads(model, optimizer, batch, rng, num_microbatches,
                         data_partition_spec):
  """Fake `accumulate_grads_microbatched` that also returns flax mutables.

  Every gradient leaf equals the sum of `batch['i']`; loss and accuracy are
  `metrics_lib.Sum.from_model_output` of the sum of `batch['j']`.

  Returns:
    `(grad_accum, metrics, {'mutables': 0})`.
  """
  del model, num_microbatches, rng, data_partition_spec
  i_total = batch['i'].sum()
  j_total = batch['j'].sum()
  grad_accum = jax.tree_util.tree_map(lambda _: i_total, optimizer)
  metrics = {
      name: metrics_lib.Sum.from_model_output(j_total)
      for name in ('loss', 'accuracy')
  }
  return grad_accum, metrics, {'mutables': 0}
def fake_mut_apply_grads(optimizer, grad_accum, metrics, learning_rate,
                         weight_metrics_computer, other_state_variables):
  """Fake `apply_grads`: adds `grad_accum` leaf-wise to `optimizer`.

  Records `learning_rate` via `clu.metrics.Average.from_model_output`; note
  that `metrics` is mutated in place and returned.

  Returns:
    `(updated_optimizer, metrics)`.
  """
  del weight_metrics_computer, other_state_variables
  metrics['learning_rate'] = clu.metrics.Average.from_model_output(
      learning_rate)
  # `jax.tree_multimap` was removed from JAX; `tree_map` takes extra trees.
  optimizer = jax.tree_util.tree_map(lambda x, g: x + g, optimizer, grad_accum)
  return optimizer, metrics
class MutableTrainerTest(parameterized.TestCase):
  """Tests a train step whose gradient function also returns flax mutables."""

  def setUp(self):
    super().setUp()
    self.init_optimizer = optimizers.Optimizer(
        optimizers.sgd(0.1),
        state=optimizers.OptimizerState(
            step=0, param_states={
                'bias': 0,
                'kernel': 0
            }),
        target={
            'bias': np.zeros(4),
            'kernel': np.zeros((2, 4))
        })
    self.init_train_state = train_state_lib.FlaxOptimTrainState(
        self.init_optimizer)
    train_state_axes = jax.tree_util.tree_map(lambda x: None,
                                              self.init_train_state)
    model_dir = self.create_tempdir().full_path

    # Three batches of two examples: `i` is (0, 1), (2, 3), (4, 5) and `j` is
    # always 1.
    mapfn = lambda i: {'i': [tf.cast(i, tf.int32)], 'j': [tf.cast(1, tf.int32)]}
    self.dataset = tf.data.Dataset.range(6).map(mapfn).batch(
        2, drop_remainder=True)
    self.test_trainer = trainer_lib.Trainer(
        mock.create_autospec(models_lib.BaseModel, instance=True),
        self.init_train_state,
        partitioning.PjitPartitioner(num_partitions=1),
        eval_names=['task1', 'task2'],
        summary_dir=model_dir,
        train_state_axes=train_state_axes,
        rng=np.ones(2, np.uint32),
        learning_rate_fn=lambda step: 2 * (step + 1),
        num_microbatches=None)

  @mock.patch('time.time')
  @mock.patch('t5x.trainer.accumulate_grads_microbatched', fake_mut_accum_grads)
  @mock.patch('t5x.trainer.apply_grads', fake_mut_apply_grads)
  # The two logging patches below avoid calls to time.time() during logging.
  @mock.patch('absl.logging.info', lambda *_: None)
  @mock.patch('absl.logging.log_every_n_seconds', lambda *_: None)
  def test_train(self, mock_time=None):
    trainer = self.test_trainer
    initial_rng = trainer._base_rng

    trainer._partitioned_train_step = mock.Mock(
        side_effect=trainer._partitioned_train_step)

    # train start, logging, train end, logging
    mock_time.side_effect = [1, 5, 5, 5]
    num_steps = 1
    ds_iter = self.dataset.as_numpy_iterator()
    batch = next(ds_iter)
    train_state, _ = trainer._partitioned_train_step(trainer.train_state, batch)

    # The first batch has i = (0, 1), so every value grows by 0 + 1 = 1.
    expected_train_state = jax.tree_util.tree_map(lambda x: np.array(x + 1),
                                                  self.init_train_state)
    # Base rng must remain the same
    np.testing.assert_array_equal(trainer._base_rng, initial_rng)
    jax.tree_util.tree_map(np.testing.assert_equal, train_state,
                           expected_train_state)

    self.assertIsNone(trainer._compiled_train_step)
    self.assertEqual(trainer._partitioned_train_step.call_count, num_steps)

  def tearDown(self) -> None:
    # Manually close managers to avoid phantom threads crossing test cases.
    self.test_trainer.train_metrics_manager.close()
    for mm in self.test_trainer.eval_metrics_managers.values():
      mm.close()
    return super().tearDown()
if __name__ == '__main__':
  # Run the tests in this module through absltest.
  absltest.main()
|