# youtube-music-transcribe / t5x / checkpoints_test.py
# juancopi81
# Add t5x and mt3 models
# b100e1c
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for t5x.checkpoints."""
import concurrent.futures
import functools
import itertools
import os
from typing import Any, Mapping
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
from flax import serialization
from flax import traverse_util
from flax.metrics import tensorboard
import jax
import jax.numpy as jnp
import numpy as np
from t5x import checkpoints
from t5x import optimizers
from t5x import partitioning
from t5x import state_utils
from t5x import test_utils
from t5x import train_state as train_state_lib
from t5x import utils
import tensorflow as tf
from tensorflow.io import gfile
import tensorstore as ts
# Parse absl flags test_srcdir and test_tmpdir.
jax.config.parse_flags_with_absl()

# Short aliases used by the tests in this module.
FLAGS = flags.FLAGS
mock = absltest.mock
LazyArray = checkpoints.LazyArray
PartitionSpec = partitioning.PartitionSpec
FlaxOptimTrainState = train_state_lib.FlaxOptimTrainState

# `testdata` directory located next to this source file.
TESTDATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'testdata')
def make_train_state(
    *,
    step: int,
    params: Mapping[str, Any],
    param_states: Mapping[str, Any],
    flax_optimizer_def: optimizers.OptimizerDefType = optimizers.sgd(0.1)
) -> FlaxOptimTrainState:
  """Builds a `FlaxOptimTrainState` around an `optimizers.Optimizer` for tests.

  Args:
    step: value stored as the optimizer state's `step`.
    params: parameter tree used as the optimizer target.
    param_states: per-parameter optimizer state tree.
    flax_optimizer_def: optimizer definition. The default is the single
      `optimizers.sgd(0.1)` object created when this function is defined, so
      every call that omits it shares that same object.

  Returns:
    A train state wrapping the assembled optimizer.
  """
  opt_state = optimizers.OptimizerState(step=step, param_states=param_states)
  wrapped_optimizer = optimizers.Optimizer(
      flax_optimizer_def, state=opt_state, target=params)
  return FlaxOptimTrainState(wrapped_optimizer)
def make_train_state_multi_optimizer(params: Mapping[str, Any],
                                     param_states: Mapping[str, Any],
                                     step: int) -> FlaxOptimTrainState:
  """Builds a train state whose optimizer is a one-entry `MultiOptimizer`.

  The only sub-optimizer is `optimizers.sgd(0.1)`. Its traversal filter selects
  parameter paths that do not contain 'kernel'.

  Args:
    params: parameter tree used as the optimizer target.
    param_states: per-parameter optimizer state tree.
    step: value stored as the optimizer state's `step`.

  Returns:
    A train state wrapping the assembled multi-optimizer.
  """

  def _path_excludes_kernel(path, _):
    return 'kernel' not in path

  non_kernel_traversal = traverse_util.ModelParamTraversal(
      _path_excludes_kernel)
  multi_def = optimizers.MultiOptimizer([
      (non_kernel_traversal, optimizers.sgd(0.1)),
  ])
  opt_state = optimizers.OptimizerState(step=step, param_states=param_states)
  return FlaxOptimTrainState(
      optimizers.Optimizer(multi_def, state=opt_state, target=params))
def update_train_state_step(train_state: FlaxOptimTrainState,
                            step: int) -> FlaxOptimTrainState:
  """Returns `train_state` restored from its own state dict with a new step.

  The dict returned by `train_state.state_dict()` has its `['state']['step']`
  entry overwritten in place and is then passed to `restore_state`.

  Args:
    train_state: train state to copy.
    step: new step value.

  Returns:
    Whatever `train_state.restore_state` returns for the modified dict.
  """
  full_state = train_state.state_dict()
  inner_state = full_state['state']
  inner_state['step'] = step
  return train_state.restore_state(full_state)
class CheckpointChunkShapeTest(absltest.TestCase):
  """Tests for `checkpoints._choose_chunk_shape`."""

  def test_simple(self):
    """Checks the chunk shape chosen for several shapes and size arguments."""
    cases = [
        # (expected chunk shape, array shape, second argument)
        ([4096, 4096], [4096, 4096], 4096 * 4096),
        ([4096, 4096], [8192, 8192], 4096 * 4096),
        ([4096, 2731], [8192, 8193], 4096 * 4096),
        ([4096], [8192], 4096),
        ([2731], [8193], 4096),
    ]
    for expected_chunks, shape, size in cases:
      self.assertEqual(expected_chunks,
                       checkpoints._choose_chunk_shape(shape, size))
class CheckpointsTest(parameterized.TestCase):
def setUp(self):
  """Builds fixture train states, default mesh axes, a dataset and temp dirs."""
  super().setUp()

  def _fresh_params():
    # New arrays on every call, so the two train states never share buffers.
    return {
        'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)),
        'kernel': np.arange(32, dtype=np.float32).reshape((2, 16)),
    }

  self.train_state = make_train_state(
      step=np.int32(42),
      params=_fresh_params(),
      param_states={'bias': np.int32(1), 'kernel': np.array([1, 2], np.uint8)})
  self.train_state_multi_optimizer = make_train_state_multi_optimizer(
      step=np.int32(42),
      params=_fresh_params(),
      param_states={'bias': np.int32(1), 'kernel': None})
  # A train state whose params leaves are PartitionSpecs; get_partitioner
  # returns it from `get_mesh_axes` unless another value is supplied.
  self.default_mesh_axes = make_train_state(
      step=None,
      params={
          'bias': PartitionSpec('model', None),
          'kernel': PartitionSpec(None, 'model'),
      },
      param_states={'bias': None, 'kernel': None})

  self.ds = tf.data.Dataset.range(1024)

  self.checkpoints_dir = self.create_tempdir()
  self.tmp_dir = self.checkpoints_dir.full_path

  # A separate directory holding empty `checkpoint` files for self.steps.
  fake_root = self.create_tempdir()
  self.fake_checkpoints = fake_root.full_path
  self.steps = (0, 100, 200)
  for step in self.steps:
    fake_root.mkdir(f'checkpoint_{step}').create_file('checkpoint')
@mock.patch('jax._src.lib.xla_bridge.process_index')
@mock.patch('jax.devices')
@mock.patch('jax.local_devices')
def get_partitioner(self,
                    process_index,
                    host_count,
                    num_partitions,
                    local_devices_fn,
                    devices_fn,
                    process_index_fn,
                    params_on_devices: bool = True,
                    mesh_axes=None):
  """Builds a fake partitioner as seen from one host of a fake topology.

  The `mock.patch` decorators append their mocks as positional arguments,
  bottom decorator first. `local_devices_fn` patches `jax.local_devices`,
  `devices_fn` patches `jax.devices`, and `process_index_fn` patches
  `jax._src.lib.xla_bridge.process_index`. Callers therefore pass only
  `process_index`, `host_count` and `num_partitions` positionally. The patches
  are active only while this method runs, which is when the mesh and the local
  chunker are built.

  Args:
    process_index: value the patched `process_index` returns.
    host_count: key into `host_count_to_layout` (1, 2, 4, 8, 16 or 32).
    num_partitions: key into `num_partitions_to_mps` (1, 2, 4 or 16).
    local_devices_fn: injected mock for `jax.local_devices`.
    devices_fn: injected mock for `jax.devices`.
    process_index_fn: injected mock for `xla_bridge.process_index`.
    params_on_devices: forwarded to `BasePartitioner.__init__`. When False,
      `move_params_to_devices` fails its assertion if it is called.
    mesh_axes: value returned by `get_mesh_axes`. Falls back to
      `self.default_mesh_axes` when it is falsy.

  Returns:
    A `partitioning.BasePartitioner` subclass instance. Its `partition` and
    `compile` raise `NotImplementedError`.
  """
  # Arguments unpacked into test_utils.make_devices; presumably the fake
  # device-grid layout per host count -- TODO confirm against test_utils.
  host_count_to_layout = {
      1: (2, 2, 1, 2),
      2: (4, 2, 1, 2),
      4: (4, 4, 1, 2),
      8: (4, 8, 1, 2),
      16: (8, 8, 1, 2),
      32: (8, 16, 1, 2)
  }
  devices = test_utils.make_devices(*host_count_to_layout[host_count])
  devices_fn.return_value = devices
  # NOTE(review): local devices are always those of process 0, whatever
  # `process_index` is. Presumably get_mesh only needs their count/layout;
  # confirm.
  local_devices = [d for d in devices if d.process_index == 0]
  local_devices_fn.return_value = local_devices
  process_index_fn.return_value = process_index
  # `model_parallel_submesh` passed to get_mesh for each partition count.
  num_partitions_to_mps = {
      1: (1, 1, 1, 1),
      2: (1, 1, 1, 2),
      4: (2, 1, 1, 2),
      16: (4, 2, 1, 2)
  }
  mesh = partitioning.get_mesh(
      model_parallel_submesh=num_partitions_to_mps[num_partitions],
      input_devices=devices,
      input_local_devices=local_devices)
  # NOTE(review): `or` also replaces any falsy-but-intended mesh_axes value,
  # not only None.
  mesh_axes = mesh_axes or self.default_mesh_axes
  local_chunker = partitioning.LocalChunker(mesh)

  class TestPartitioner(partitioning.BasePartitioner):
    """Partitioner backed by the fake mesh/chunker built above."""

    def __init__(self):
      # Initialized here but never incremented within this class.
      self.move_params_to_devices_calls = 0
      super().__init__(
          num_partitions, None, params_on_devices=params_on_devices)

    @property
    def _local_chunker(self):
      return local_chunker

    @property
    def _mesh(self):
      return mesh

    def partition(self,
                  fn,
                  in_axis_resources,
                  out_axis_resources,
                  static_argnums=(),
                  donate_argnums=()):
      raise NotImplementedError

    def compile(self, partitioned_fn, *args):
      raise NotImplementedError

    def move_params_to_devices(self, train_state, train_state_axes):
      # Only expected to be reached when params live on devices.
      assert params_on_devices
      return train_state

    def get_mesh_axes(self, train_state):
      # Ignores `train_state` and returns the closed-over mesh axes.
      return mesh_axes

  return TestPartitioner()
# pylint:disable=no-value-for-parameter
@mock.patch(
    'jax.experimental.multihost_utils.sync_global_devices', return_value=None)
@mock.patch('time.time', return_value=0)
@mock.patch('jax.host_count')
@mock.patch('jax.process_index')
def call_host_checkpointer(self,
                           process_index,
                           host_count,
                           partitioner,
                           fn,
                           save_dtype,
                           ds_iter,
                           mock_process_index,
                           mock_host_count,
                           unused_mock_host_time,
                           unused_mock_sync_devices,
                           restore_dtype=np.float32):
  """Runs `fn` on a Checkpointer for `self.train_state`, posing as one host.

  The decorators append four mocks positionally, bottom decorator first:
  `jax.process_index`, `jax.host_count`, `time.time` (always 0) and
  `sync_global_devices` (a no-op).

  Args:
    process_index: value `jax.process_index()` reports during the call.
    host_count: value `jax.host_count()` reports during the call.
    partitioner: partitioner handed to the Checkpointer.
    fn: callable receiving the Checkpointer; its result is returned.
    save_dtype: `save_dtype` for the Checkpointer.
    ds_iter: dataset iterator to checkpoint, or None.
    mock_process_index: injected mock for `jax.process_index`.
    mock_host_count: injected mock for `jax.host_count`.
    unused_mock_host_time: injected mock for `time.time`.
    unused_mock_sync_devices: injected mock for `sync_global_devices`.
    restore_dtype: `restore_dtype` for the Checkpointer.

  Returns:
    The result of `fn(checkpointer)`.
  """
  mock_host_count.return_value = host_count
  mock_process_index.return_value = process_index
  host_checkpointer = checkpoints.Checkpointer(
      self.train_state,
      partitioner,
      self.tmp_dir,
      ds_iter,
      save_dtype=save_dtype,
      restore_dtype=restore_dtype)
  return fn(host_checkpointer)
# pylint:disable=no-value-for-parameter
@mock.patch(
    'jax.experimental.multihost_utils.sync_global_devices', return_value=None)
@mock.patch('time.time', return_value=0)
@mock.patch('jax.host_count')
@mock.patch('jax.process_index')
def call_host_multioptimizer_checkpointer(self, process_index, host_count,
                                          partitioner, fn, save_dtype,
                                          ds_iter, mock_process_index,
                                          mock_host_count,
                                          unused_mock_host_time,
                                          unused_mock_sync_devices):
  """Runs `fn` on a Checkpointer for the multi-optimizer train state.

  Same as `call_host_checkpointer`, except that it uses
  `self.train_state_multi_optimizer` and leaves `restore_dtype` at the
  Checkpointer's default.

  Returns:
    The result of `fn(checkpointer)`.
  """
  mock_host_count.return_value = host_count
  mock_process_index.return_value = process_index
  host_checkpointer = checkpoints.Checkpointer(
      self.train_state_multi_optimizer,
      partitioner,
      self.tmp_dir,
      ds_iter,
      save_dtype=save_dtype)
  return fn(host_checkpointer)
def test_get_parameter_infos(self):
  """Checks the `_ParameterInfo` tree built for a partitioned train state.

  Scalar and small optimizer-state leaves carry no TensorStore spec. The
  `target` arrays get a zarr spec plus this host's local chunk info.
  """
  train_state = make_train_state(
      params={
          'bias': np.ones((8192, 8192), np.float32),
          'kernel': np.ones((2, 16), np.float32)
      },
      param_states={
          'bias': np.int32(1),
          'kernel': np.array([1, 2])
      },
      step=np.int32(42))
  # host 3 of a 4x4 with mesh 'model' dim == 16
  partitioner = self.get_partitioner(3, 4, 16)
  checkpointer = checkpoints.Checkpointer(train_state, partitioner,
                                          self.tmp_dir)

  expected_parameter_infos = {
      'state': {
          'step':
              checkpoints._ParameterInfo(
                  name='state/step',
                  shape=(),
                  ts_spec=None,
                  local_chunk_info=None,
                  axes=None),
          'param_states': {
              'bias':
                  checkpoints._ParameterInfo(
                      name='state/param_states/bias',
                      shape=(),
                      ts_spec=None,
                      local_chunk_info=None, axes=None),
              'kernel':
                  checkpoints._ParameterInfo(
                      name='state/param_states/kernel',
                      shape=(2,),
                      ts_spec=None,
                      local_chunk_info=None, axes=None)
          }
      },
      'target': {
          'bias':
              checkpoints._ParameterInfo(
                  name='target/bias',
                  shape=(8192, 8192),
                  ts_spec=ts.Spec({
                      'driver': 'zarr',
                      'dtype': 'float32',
                      'kvstore': {  # pylint:disable=duplicate-key
                          'driver': 'file',
                          'path': 'target.bias',
                      },
                      'metadata': {
                          'chunks': [4096, 4096],
                          'compressor': {
                              'id': 'gzip'
                          },
                          'shape': [8192, 8192],
                      },
                  }),
                  local_chunk_info=partitioning.LocalChunkInfo(
                      slice=(slice(4096, 8192, None), slice(None, None,
                                                            None)),
                      replica_id=1), axes=PartitionSpec('model', None)),
          'kernel':
              checkpoints._ParameterInfo(
                  name='target/kernel',
                  shape=(2, 16),
                  ts_spec=ts.Spec({
                      'driver': 'zarr',
                      'dtype': 'float32',
                      'kvstore': {  # pylint:disable=duplicate-key
                          'driver': 'file',
                          'path': 'target.kernel',
                      },
                      'metadata': {
                          'chunks': [2, 8],
                          'compressor': {
                              'id': 'gzip'
                          },
                          'shape': [2, 16],
                      },
                  }),
                  local_chunk_info=partitioning.LocalChunkInfo(
                      slice=(slice(None, None, None), slice(8, 16, None)),
                      replica_id=1), axes=PartitionSpec(None, 'model'))
      }
  }  # pyformat: disable

  # `jax.tree_multimap` was removed from JAX; `jax.tree_map` accepts the
  # additional trees directly and is what the rest of this file uses.
  jax.tree_map(self.assertEqual, checkpointer._get_parameter_infos(),
               expected_parameter_infos)
def test_get_multioptimizer_parameter_infos(self):
  """A `None` optimizer state for a parameter yields a `None` info entry."""
  state = make_train_state(
      step=np.int32(42),
      params={
          'bias': np.ones((8192, 8192), jnp.bfloat16),
          'kernel': np.ones((2, 16), np.float32),
      },
      # Kernel has no optimizer state, as if a multioptimizer were not
      # updating this parameter.
      param_states={'bias': np.int32(1), 'kernel': None})
  # Host 3 of a 4x4, with mesh 'model' dim == 16.
  checkpointer = checkpoints.Checkpointer(
      state, self.get_partitioner(3, 4, 16), self.tmp_dir)
  infos = checkpointer._get_parameter_infos()
  self.assertIsNone(infos['state']['param_states']['kernel'])
def test_all_steps(self):
  """`checkpoints.all_steps` returns finalized steps in ascending order."""
  partitioner = self.get_partitioner(0, 1, 1)
  checkpointer = self.call_host_checkpointer(0, 1, partitioner, lambda c: c,
                                             np.float32, None)
  self.assertIsNone(checkpointer.latest_step())

  # '999.tmp-0' uses the temporary-directory suffix and must not be listed.
  for step_name in ('0', '42', '10', '999.tmp-0', '100'):
    step_dir = os.path.join(checkpointer.checkpoints_dir,
                            f'checkpoint_{step_name}')
    gfile.makedirs(step_dir)
    with gfile.GFile(os.path.join(step_dir, 'checkpoint'), 'w') as f:
      f.write('')

  # Passing the directory with a trailing '/' gives the same listing.
  listed = checkpoints.all_steps(checkpointer.checkpoints_dir + '/')
  self.assertSequenceEqual(listed, [0, 10, 42, 100])
def test_all_latest_step(self):
  """`Checkpointer.all_steps`/`latest_step` need a `checkpoint` file per dir."""
  partitioner = self.get_partitioner(0, 1, 1)
  checkpointer = self.call_host_checkpointer(0, 1, partitioner, lambda c: c,
                                             np.float32, None)
  self.assertIsNone(checkpointer.latest_step())
  ckpt_files = {}
  for step in ['0', '42', '10', '999.tmp-0', '100']:
    d = os.path.join(checkpointer.checkpoints_dir, f'checkpoint_{step}')
    gfile.makedirs(d)
    ckpt = os.path.join(d, 'checkpoint')
    with gfile.GFile(ckpt, 'w') as f:
      f.write('')
    ckpt_files[step] = ckpt

  self.assertSequenceEqual(checkpointer.all_steps(), [0, 10, 42, 100])
  self.assertEqual(checkpointer.latest_step(), 100)

  # Remove the checkpoint file for step 100 (but leave its directory). The path
  # is looked up by step instead of relying on the loop variable's final value.
  gfile.remove(ckpt_files['100'])
  self.assertSequenceEqual(checkpointer.all_steps(), [0, 10, 42])
  self.assertEqual(checkpointer.latest_step(), 42)
def test_all_latest_step_public(self):
  """Module-level `all_steps`/`latest_step` need a `checkpoint` file per dir."""
  self.assertIsNone(checkpoints.latest_step(self.tmp_dir))
  ckpt_files = {}
  for step in ['0', '42', '10', '999.tmp-0', '100']:
    d = os.path.join(self.tmp_dir, f'checkpoint_{step}')
    gfile.makedirs(d)
    ckpt = os.path.join(d, 'checkpoint')
    with gfile.GFile(ckpt, 'w') as f:
      f.write('')
    ckpt_files[step] = ckpt

  self.assertSequenceEqual(
      checkpoints.all_steps(self.tmp_dir), [0, 10, 42, 100])
  self.assertEqual(checkpoints.latest_step(self.tmp_dir), 100)

  # Remove the checkpoint file for step 100 (but leave its directory). The path
  # is looked up by step instead of relying on the loop variable's final value.
  gfile.remove(ckpt_files['100'])
  self.assertSequenceEqual(checkpoints.all_steps(self.tmp_dir), [0, 10, 42])
  self.assertEqual(checkpoints.latest_step(self.tmp_dir), 42)
def validate_restore(self,
                     host_count,
                     num_partitions,
                     step=42,
                     checkpoint_dataset=False,
                     expected_restore_dtype=np.float32,
                     lazy_parameters=False,
                     disable_partitioning=False):
  """Restores on every fake host and checks each host's local shard.

  For each host index in `range(host_count)`, a fresh partitioner and
  Checkpointer are built and `restore(step=step, ...)` is called. The result is
  compared against that host's slice of `self.train_state`.

  Args:
    host_count: number of fake hosts to restore on.
    num_partitions: number of model partitions in the restore topology.
    step: checkpoint step to restore and expect in the restored state.
    checkpoint_dataset: if True, also restore the dataset iterator and check
      its next element.
    expected_restore_dtype: passed as `restore_dtype` and expected as the dtype
      of every restored parameter.
    lazy_parameters: restore lazily, then materialize every leaf with `.get()`.
    disable_partitioning: use `jax.tree_map(lambda x: None, ...)` of the default
      mesh axes instead of the defaults.
  """
  params = self.train_state.params
  param_states = self.train_state.param_states
  for i in range(host_count):
    partitioner = self.get_partitioner(
        i,
        host_count,
        num_partitions,
        params_on_devices=not lazy_parameters,
        mesh_axes=jax.tree_map(lambda x: None, self.default_mesh_axes)
        if disable_partitioning else None)
    ds_shard_id = partitioner.get_data_layout().shard_id
    bias_slice = partitioner.get_local_chunk_info(params['bias'].shape,
                                                  ('model', None)).slice
    kernel_slice = partitioner.get_local_chunk_info(params['kernel'].shape,
                                                    (None, 'model')).slice
    ds_iter = iter(self.ds)

    actual_train_state = self.call_host_checkpointer(
        i,
        host_count,
        partitioner,
        lambda c: c.restore(  # pylint: disable=g-long-lambda
            step=step,
            lazy_parameters=lazy_parameters),
        np.float32,
        ds_iter if checkpoint_dataset else None,
        restore_dtype=expected_restore_dtype)
    if lazy_parameters:
      actual_train_state = jax.tree_map(lambda x: x.get(), actual_train_state)
    self.assertEqual(actual_train_state._optimizer.optimizer_def,
                     self.train_state._optimizer.optimizer_def)

    self.assertEqual(actual_train_state.step, step)
    self.assertEqual(actual_train_state.step.dtype, np.int32)
    self.assertEqual(actual_train_state._optimizer.state.step.dtype, np.int32)
    # `jax.tree_multimap` was removed from JAX; `jax.tree_map` takes the extra
    # tree directly.
    jax.tree_map(np.testing.assert_array_equal,
                 actual_train_state.param_states, param_states)
    self.assertEqual(actual_train_state.param_states['kernel'].dtype,
                     np.uint8)
    self.assertSameElements(actual_train_state.params, ('bias', 'kernel'))
    self.assertTrue(
        all(
            jax.tree_leaves(
                jax.tree_map(lambda x: x.dtype == expected_restore_dtype,
                             actual_train_state.params))))
    np.testing.assert_equal(actual_train_state.params['bias'],
                            params['bias'][bias_slice])
    np.testing.assert_equal(actual_train_state.params['kernel'],
                            params['kernel'][kernel_slice])
    if checkpoint_dataset:
      # The next value from the restored iterator should equal the
      # replica set id.
      self.assertEqual(next(ds_iter).numpy(), ds_shard_id)
def validate_multioptimizer_restore(self,
                                    host_count,
                                    num_partitions,
                                    step=42,
                                    checkpoint_dataset=False,
                                    expected_restore_dtype=np.float32):
  """Restores into the multi-optimizer train state on every fake host.

  NOTE(review): `expected_restore_dtype` is only asserted. It is not forwarded
  to the Checkpointer, because `call_host_multioptimizer_checkpointer` takes no
  restore dtype.

  Args:
    host_count: number of fake hosts to restore on.
    num_partitions: number of model partitions in the restore topology.
    step: checkpoint step to restore and expect in the restored state.
    checkpoint_dataset: if True, also restore the dataset iterator and check
      its next element.
    expected_restore_dtype: dtype expected on every restored parameter.
  """
  params = self.train_state_multi_optimizer.params
  param_states = self.train_state_multi_optimizer.param_states
  for i in range(host_count):
    partitioner = self.get_partitioner(i, host_count, num_partitions)
    ds_shard_id = partitioner.get_data_layout().shard_id
    bias_slice = partitioner.get_local_chunk_info(params['bias'].shape,
                                                  ('model', None)).slice
    kernel_slice = partitioner.get_local_chunk_info(params['kernel'].shape,
                                                    (None, 'model')).slice
    ds_iter = iter(self.ds)

    actual_train_state = self.call_host_multioptimizer_checkpointer(
        i, host_count, partitioner, lambda c: c.restore(step=step),
        np.float32, ds_iter if checkpoint_dataset else None)
    actual_optimizer = actual_train_state._optimizer  # pylint: disable=protected-access
    actual_step = actual_train_state.step
    actual_params = actual_train_state.params
    actual_param_states = actual_train_state.param_states

    self.assertEqual(
        actual_optimizer.optimizer_def,
        self.train_state_multi_optimizer._optimizer.optimizer_def)
    self.assertEqual(actual_optimizer.state.step.dtype, np.int32)
    jax.tree_map(lambda x: self.assertEqual(x.dtype, expected_restore_dtype),
                 actual_optimizer.target)
    self.assertEqual(actual_step, step)
    self.assertEqual(actual_step.dtype, np.int32)
    # `jax.tree_multimap` was removed from JAX; `jax.tree_map` takes the extra
    # tree directly.
    jax.tree_map(np.testing.assert_array_equal, actual_param_states,
                 param_states)
    self.assertSameElements(actual_params, ('bias', 'kernel'))
    self.assertTrue(
        all(
            jax.tree_leaves(
                jax.tree_map(lambda x: x.dtype == expected_restore_dtype,
                             actual_params))))
    np.testing.assert_equal(actual_params['bias'], params['bias'][bias_slice])
    np.testing.assert_equal(actual_params['kernel'],
                            params['kernel'][kernel_slice])
    if checkpoint_dataset:
      # The next value from the restored iterator should equal the
      # replica set id.
      self.assertEqual(next(ds_iter).numpy(), ds_shard_id)
def validate_save(self,
                  host_count,
                  num_partitions,
                  step=42,
                  save_dtype=np.float32,
                  checkpoint_dataset=False,
                  multi_optimizer=False,
                  disable_partitioning=False):
  """Saves from every fake host and checks what each save wrote.

  Hosts are visited in reverse order because host 0 finalizes the checkpoint
  by moving it out of the temporary `checkpoint_<step>.tmp-0` directory.
  Unless partitioning is disabled, each host's save is followed by checks on
  the TensorStore arrays written so far and, optionally, on the dataset
  checkpoint files. After all hosts have saved, the msgpack `checkpoint` file
  is checked.

  Args:
    host_count: number of fake hosts that save.
    num_partitions: number of model partitions in the save topology.
    step: step recorded in the saved train state.
    save_dtype: `save_dtype` for the Checkpointer and dtype expected on disk.
    checkpoint_dataset: if True, also checkpoint a dataset iterator.
    multi_optimizer: save the multi-optimizer fixture instead of the default.
    disable_partitioning: save with None mesh axes. Parameters are then
      expected inside the msgpack file instead of in TensorStore.
  """
  if multi_optimizer:
    params = self.train_state_multi_optimizer.params
    param_states = self.train_state_multi_optimizer.param_states
    optimizer_def = self.train_state_multi_optimizer._optimizer.optimizer_def
  else:
    params = self.train_state.params
    param_states = self.train_state.param_states
    optimizer_def = self.train_state._optimizer.optimizer_def

  step = np.int32(step)
  # Filled in chunk by chunk as each host's save is verified.
  expected_bias = np.zeros((4, 1), save_dtype)
  expected_kernel = np.zeros((2, 16), save_dtype)
  bias_tspec = {
      'driver': 'zarr',
      'kvstore': {
          'driver': 'file',
          'path': f'{self.tmp_dir}/checkpoint_{step}.tmp-0/target.bias',
      }
  }
  kernel_tspec = {
      'driver': 'zarr',
      'kvstore': {
          'driver': 'file',
          'path': f'{self.tmp_dir}/checkpoint_{step}.tmp-0/target.kernel',
      }
  }

  # Test save.
  # Each host saves only its own slice of `params`.
  # Go in reverse since host 0 renames the directory.
  for i in reversed(range(host_count)):
    partitioner = self.get_partitioner(
        i,
        host_count,
        num_partitions,
        mesh_axes=jax.tree_map(lambda x: None, self.default_mesh_axes)
        if disable_partitioning else None)
    data_layout = partitioner.get_data_layout()
    num_ds_shards = data_layout.num_shards
    ds_shard_id = data_layout.shard_id
    # Every other call passes a global *shape*; pass the 1-D shape of a
    # 'data'-sharded array instead of an array of ones.
    chunk_id_for_shard = partitioner.get_local_chunk_info(
        (num_ds_shards,), ['data']).replica_id

    bias_chunk = partitioner.get_local_chunk_info(params['bias'].shape,
                                                  ('model', None))
    kernel_chunk = partitioner.get_local_chunk_info(params['kernel'].shape,
                                                    (None, 'model'))

    ds_iter = iter(self.ds)

    # pylint:disable=cell-var-from-loop
    def _save_ckpt(checkpointer):
      # Set the checkpoint so that the next value on restore will be the
      # replica set id.
      for _ in range(ds_shard_id):
        next(ds_iter)

      train_state = make_train_state(
          step=step,
          params={
              'bias': params['bias'][bias_chunk.slice],
              'kernel': params['kernel'][kernel_chunk.slice]
          },
          param_states=param_states,
          flax_optimizer_def=optimizer_def)
      checkpointer.save(train_state)
    # pylint:enable=cell-var-from-loop

    self.call_host_checkpointer(i, host_count, partitioner, _save_ckpt,
                                save_dtype,
                                ds_iter if checkpoint_dataset else None)

    if disable_partitioning:
      continue

    # Read the current TensorStore.
    if i == 0:
      # Host 0 moves the files.
      bias_tspec['kvstore']['path'] = (
          bias_tspec['kvstore']['path'].replace('.tmp-0', ''))
      kernel_tspec['kvstore']['path'] = (
          kernel_tspec['kvstore']['path'].replace('.tmp-0', ''))

    if checkpoint_dataset:
      ckpt_dir = f'{self.tmp_dir}/checkpoint_{step}'
      if i != 0:
        ckpt_dir += '.tmp-0'
      ds_ckpt_glob = gfile.glob(ckpt_dir + '/train_ds-' +
                                f'{ds_shard_id:03}-of-{num_ds_shards:03}*')
      # Only the host with replica id 0 for its data shard writes the files.
      if chunk_id_for_shard == 0:
        self.assertLen(ds_ckpt_glob, 2)
      else:
        self.assertEmpty(ds_ckpt_glob)

    # only replica_id=0 is saved for each array chunk
    if bias_chunk.replica_id == 0:
      current_bias = ts.open(bias_tspec).result().read().result().view(
          save_dtype)
      expected_bias[bias_chunk.slice] = (params['bias'][bias_chunk.slice])
      np.testing.assert_equal(current_bias, expected_bias)

    if kernel_chunk.replica_id == 0:
      current_kernel = ts.open(kernel_tspec).result().read().result().view(
          save_dtype)
      expected_kernel[kernel_chunk.slice] = (
          params['kernel'][kernel_chunk.slice])
      np.testing.assert_equal(current_kernel, expected_kernel)

  with gfile.GFile(f'{self.tmp_dir}/checkpoint_{step}/checkpoint', 'rb') as f:
    ckpt_contents = serialization.msgpack_restore(f.read())
  self.assertEqual(ckpt_contents['version'], checkpoints.VERSION)
  # `jax.tree_multimap` was removed from JAX; `jax.tree_map` takes the extra
  # tree directly.
  jax.tree_map(np.testing.assert_allclose,
               ckpt_contents['optimizer']['state']['param_states'],
               param_states)
  self.assertEqual(ckpt_contents['optimizer']['state']['step'].dtype,
                   np.int32)
  if disable_partitioning:
    # Parameters should also be in the msgpack checkpoint file.
    jax.tree_map(
        np.testing.assert_allclose, ckpt_contents['optimizer']['target'],
        jax.tree_map(lambda arr: arr.astype(save_dtype), params))

  # Jax tree maps ignore Nones so actually check this value is None
  if multi_optimizer:
    self.assertIsNone(
        ckpt_contents['optimizer']['state']['param_states']['kernel'])
# (host_count, num_partitions) pairs that parameterize the save/restore tests.
TOPOLOGIES = [
    (1, 1),  # 1 host, 1 partition
    (1, 2),  # 1 host, 2 partitions
    (2, 1),  # 2 hosts, 1 partition
    (2, 2),  # 2 hosts, 2 partitions
    (4, 4),  # 4 hosts, 4 partitions
    (4, 1),  # 4 hosts, 1 partition
    (4, 2),  # 4 hosts, 2 partitions
    (8, 2),  # 8 hosts, 2 partitions
]

# Dtypes crossed with each other as (save, restore) pairs in test_save_as_type.
DTYPES = [
    jnp.int32,
    jnp.float32,
    jnp.bfloat16,
    jnp.uint32,
    jnp.int64,
    jnp.float64,
]
@parameterized.parameters(itertools.product(TOPOLOGIES, TOPOLOGIES))
def test_save_restore(self, save_topology, restore_topology):
  """Round-trips a checkpoint across every (save, restore) topology pair."""
  save_hosts, save_partitions = save_topology
  self.validate_save(save_hosts, save_partitions)
  self.validate_restore(*restore_topology)
@parameterized.parameters(itertools.product(TOPOLOGIES, TOPOLOGIES))
def test_save_restore_lazy(self, save_topology, restore_topology):
  """Like test_save_restore, but restores parameters lazily."""
  save_hosts, save_partitions = save_topology
  self.validate_save(save_hosts, save_partitions)
  self.validate_restore(*restore_topology, lazy_parameters=True)
@parameterized.parameters(itertools.product(TOPOLOGIES, TOPOLOGIES))
def test_save_multioptimizer_restore(self, save_topology, restore_topology):
  """Saves the default train state and restores it as a multi-optimizer one."""
  save_hosts, save_partitions = save_topology
  self.validate_save(save_hosts, save_partitions)
  self.validate_multioptimizer_restore(*restore_topology)
@parameterized.parameters(itertools.product(TOPOLOGIES, TOPOLOGIES))
def test_multioptimizer_save_multioptimizer_restore(self, save_topology,
                                                    restore_topology):
  """Round-trips the multi-optimizer train state across topology pairs."""
  save_hosts, save_partitions = save_topology
  self.validate_save(save_hosts, save_partitions, multi_optimizer=True)
  self.validate_multioptimizer_restore(*restore_topology)
def test_load_t5x_checkpoint(self):
  """`load_t5x_checkpoint` returns the saved train state's state dict."""
  self.validate_save(1, 1)
  ckpt = checkpoints.load_t5x_checkpoint(self.tmp_dir)
  # `jax.tree_multimap` was removed from JAX; `jax.tree_map` is equivalent.
  jax.tree_map(np.testing.assert_array_equal,
               self.train_state.state_dict(), ckpt)
def test_load_t5x_checkpoint_of_multioptimizer(self):
  """`load_t5x_checkpoint` round-trips a multi-optimizer state dict."""
  self.validate_save(1, 1, multi_optimizer=True)
  ckpt = checkpoints.load_t5x_checkpoint(self.tmp_dir)
  # `jax.tree_multimap` was removed from JAX; `jax.tree_map` is equivalent.
  jax.tree_map(np.testing.assert_array_equal,
               self.train_state_multi_optimizer.state_dict(), ckpt)
  # Jax tree maps ignore Nones so actually check this value is None
  self.assertIsNone(ckpt['state']['param_states']['kernel'])
def test_load_t5x_checkpoint_lazy(self):
  """Lazy and eager `load_t5x_checkpoint` give the same values."""
  self.validate_save(1, 1)
  ckpt = checkpoints.load_t5x_checkpoint(self.tmp_dir)
  lazy_ckpt = checkpoints.load_t5x_checkpoint(
      self.tmp_dir, lazy_parameters=True)
  lazy_loaded_ckpt = jax.tree_map(lambda x: x.get(), lazy_ckpt)
  # `jax.tree_multimap` was removed from JAX; `jax.tree_map` is equivalent.
  jax.tree_map(np.testing.assert_array_equal, ckpt, lazy_loaded_ckpt)
def test_load_t5x_checkpoint_of_multioptimizer_lazy(self):
  """Lazy and eager loads of a multi-optimizer checkpoint give equal values."""
  self.validate_save(1, 1, multi_optimizer=True)
  ckpt = checkpoints.load_t5x_checkpoint(self.tmp_dir)
  lazy_ckpt = checkpoints.load_t5x_checkpoint(
      self.tmp_dir, lazy_parameters=True)
  lazy_loaded_ckpt = jax.tree_map(lambda x: x.get(), lazy_ckpt)
  # `jax.tree_multimap` was removed from JAX; `jax.tree_map` is equivalent.
  jax.tree_map(np.testing.assert_array_equal, ckpt, lazy_loaded_ckpt)
  # Jax tree maps ignore Nones so actually check this value is None
  self.assertIsNone(lazy_loaded_ckpt['state']['param_states']['kernel'])
@parameterized.parameters(TOPOLOGIES)
def test_save_restore_dataset(self, *topology):
  """Dataset iterator state round-trips when the topology is unchanged."""
  # The number of replica sets must match between save and restore, so the
  # same topology is used for both.
  host_count, num_partitions = topology
  self.validate_save(host_count, num_partitions, checkpoint_dataset=True)
  self.validate_restore(host_count, num_partitions, checkpoint_dataset=True)
@parameterized.parameters(itertools.product(DTYPES, DTYPES))
def test_save_as_type(self, save_dtype, restore_dtype):
  """Saving as one dtype and restoring as another yields `restore_dtype`."""
  hosts, partitions = 1, 1
  self.validate_save(hosts, partitions, save_dtype=save_dtype)
  self.validate_restore(
      hosts, partitions, expected_restore_dtype=restore_dtype)
@parameterized.parameters(TOPOLOGIES)
def test_reload_wrong_shape(self, *restore_topology):
  """Restoring into a state whose kernel shape differs raises ValueError."""
  self.validate_save(1, 1)
  # Same tree as the saved state, except the kernel is (4, 8), not (2, 16).
  reshaped_params = {
      'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)),
      'kernel': np.arange(32, dtype=np.float32).reshape((4, 8)),
  }
  self.train_state = make_train_state(
      step=np.int32(42),
      params=reshaped_params,
      param_states={'bias': np.int32(1), 'kernel': np.array([1, 2])})
  expected_message = ('Shape of `target/kernel` in checkpoint (2, 16) does '
                      'not match expected (4, 8).')
  with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
    self.validate_restore(*restore_topology)
@parameterized.parameters(TOPOLOGIES)
def test_save_partitioned_restore_non_partitioned(self, *restore_topology):
  """A checkpoint saved with partitioning restores without partitioning."""
  save_hosts, save_partitions = 2, 2
  # Save with the default partitioning.
  self.validate_save(save_hosts, save_partitions)
  # Restore without partitioning.
  self.validate_restore(*restore_topology, disable_partitioning=True)
@parameterized.parameters(TOPOLOGIES)
def test_save_non_partitioned_restore_partitioned(self, *restore_topology):
  """A checkpoint saved without partitioning restores with partitioning."""
  save_hosts, save_partitions = 2, 1
  # Save without partitioning.
  self.validate_save(save_hosts, save_partitions, disable_partitioning=True)
  # Restore with the default partitioning.
  self.validate_restore(*restore_topology)
@parameterized.parameters(TOPOLOGIES)
def test_save_non_partitioned_restore_non_partitioned(self,
                                                      *restore_topology):
  """A checkpoint saved without partitioning restores without it too."""
  # Save without partitioning.
  self.validate_save(2, 1, disable_partitioning=True)
  # Restore without partitioning.
  self.validate_restore(*restore_topology, disable_partitioning=True)
@mock.patch('time.time', return_value=0)
def test_keep(self, unused_mock_time):
  """After each save, only the newest `keep` checkpoints remain."""
  checkpointer = checkpoints.Checkpointer(
      self.train_state, self.get_partitioner(0, 1, 1), self.tmp_dir, keep=2)
  schedule = [
      # (new value for `keep`, or None to leave it; step; expected steps)
      (None, 42, [42]),
      (None, 43, [42, 43]),
      (None, 44, [43, 44]),
      (1, 45, [45]),
      (3, 46, [45, 46]),
  ]
  for new_keep, step, expected_steps in schedule:
    if new_keep is not None:
      checkpointer.keep = new_keep
    checkpointer.save(update_train_state_step(self.train_state, step))
    self.assertSequenceEqual(checkpointer.all_steps(), expected_steps)
@mock.patch('time.time', return_value=0)
def test_keep_pinned(self, unused_mock_time):
  """A checkpoint with a `PINNED` file is never removed by `keep` pruning."""
  no_partitions_partitioner = self.get_partitioner(0, 1, 1)
  train_state = self.train_state
  checkpointer = checkpoints.Checkpointer(
      train_state, no_partitions_partitioner, self.tmp_dir, keep=1)
  checkpointer.save(update_train_state_step(train_state, 42))
  self.assertSequenceEqual(checkpointer.all_steps(), [42])

  # Pin checkpoint 42 by creating a `PINNED` file in its directory.
  ckpt_dir = self.checkpoints_dir.mkdir('checkpoint_42')
  ckpt_dir.create_file('PINNED')

  checkpointer.save(update_train_state_step(train_state, 43))
  # Assert both the pinned and the most recent checkpoints are saved.
  self.assertSequenceEqual(checkpointer.all_steps(), [42, 43])

  checkpointer.save(update_train_state_step(train_state, 44))
  # Assert the non-pinned checkpoint gets deleted, but the pinned and the most
  # recent one are still saved.
  self.assertSequenceEqual(checkpointer.all_steps(), [42, 44])
@mock.patch('time.time', return_value=0)
def test_keep_dataset_checkpoints(self, unused_mock_time):
  """`keep` prunes checkpoints; `keep_dataset_checkpoints` prunes datasets."""
  checkpointer = checkpoints.Checkpointer(
      self.train_state,
      self.get_partitioner(0, 1, 1),
      self.tmp_dir,
      dataset_iterator=iter(tf.data.Dataset.range(10)),
      keep=2,
      keep_dataset_checkpoints=1)
  schedule = [
      # (new `keep` or None, step, expected steps, expected dataset steps)
      (None, 42, [42], [42]),
      (None, 43, [42, 43], [43]),
      (None, 44, [43, 44], [44]),
      (1, 45, [45], [45]),
      (3, 46, [45, 46], [46]),
  ]
  for new_keep, step, expected_steps, expected_ds_steps in schedule:
    if new_keep is not None:
      checkpointer.keep = new_keep
    checkpointer.save(update_train_state_step(self.train_state, step))
    self.assertSequenceEqual(checkpointer.all_steps(), expected_steps)
    self.assertSequenceEqual(checkpointer.all_dataset_checkpoint_steps(),
                             expected_ds_steps)
@mock.patch('time.time', return_value=0)
def test_keep_dataset_checkpoints_pinned(self, unused_mock_time):
  """A pinned checkpoint keeps its dataset checkpoint, too."""
  no_partitions_partitioner = self.get_partitioner(0, 1, 1)
  train_state = self.train_state
  dataset_iterator = iter(tf.data.Dataset.range(10))
  checkpointer = checkpoints.Checkpointer(
      train_state,
      no_partitions_partitioner,
      self.tmp_dir,
      dataset_iterator=dataset_iterator,
      keep=1,
      keep_dataset_checkpoints=1)

  checkpointer.save(update_train_state_step(train_state, 42))
  self.assertSequenceEqual(checkpointer.all_steps(), [42])

  # Pin checkpoint 42 by creating a `PINNED` file in its directory.
  ckpt_dir = self.checkpoints_dir.mkdir('checkpoint_42')
  ckpt_dir.create_file('PINNED')

  checkpointer.save(update_train_state_step(train_state, 43))
  # Assert both the pinned and the most recent checkpoints are saved.
  self.assertSequenceEqual(checkpointer.all_steps(), [42, 43])
  self.assertSequenceEqual(checkpointer.all_dataset_checkpoint_steps(),
                           [42, 43])

  checkpointer.save(update_train_state_step(train_state, 44))
  # Assert the non-pinned checkpoint gets deleted, but the pinned and the most
  # recent one are still saved.
  self.assertSequenceEqual(checkpointer.all_steps(), [42, 44])
  self.assertSequenceEqual(checkpointer.all_dataset_checkpoint_steps(),
                           [42, 44])
@mock.patch('time.time', return_value=0)
def test_keep_with_save_best_checkpointer(self, unused_mock_time):
  """SaveBestCheckpointer prunes by monitored metric and keeps the newest."""
  checkpointer = checkpoints.SaveBestCheckpointer(
      self.train_state,
      self.get_partitioner(0, 1, 1),
      self.tmp_dir,
      keep=2,
      metric_name_to_monitor='train/accuracy',
      metric_mode='max',
      keep_checkpoints_without_metrics=False)

  def save_at(step):
    checkpointer.save(update_train_state_step(self.train_state, step))

  # Without a valid set of metrics, deletion falls back to the oldest step
  # (keep_checkpoints_without_metrics is False).
  for step, expected_steps in ((41, [41]), (42, [41, 42]),
                               (43, [41, 42, 43]), (44, [42, 43, 44])):
    save_at(step)
    self.assertSequenceEqual(checkpointer.all_steps(), expected_steps)

  # Record accuracy metrics for steps 42, 43 and 44.
  writer = tensorboard.SummaryWriter(os.path.join(self.tmp_dir, 'train'))
  for step, accuracy in ((42, 0.9), (43, 0.8), (44, 0.7)):
    writer.scalar('accuracy', accuracy, step)

  # Both the newest checkpoint (no metric yet) and the best-accuracy ones
  # are kept.
  save_at(45)
  self.assertSequenceEqual(checkpointer.all_steps(), [42, 43, 45])

  # Switch to `min` mode: the highest-accuracy checkpoints become the worst.
  checkpointer._metric_mode = 'min'
  # Give the newest checkpoint a metric, then add two more checkpoints.
  writer.scalar('accuracy', 0.95, 45)
  save_at(46)
  writer.scalar('accuracy', 0.99, 46)
  save_at(47)
  self.assertSequenceEqual(checkpointer.all_steps(), [42, 43, 47])
@mock.patch('time.time', return_value=0)
def test_keep_pinned_save_best_checkpointer(self, unused_mock_time):
  """A pinned checkpoint survives SaveBestCheckpointer pruning until unpinned."""
  no_partitions_partitioner = self.get_partitioner(0, 1, 1)
  train_state = self.train_state
  checkpointer = checkpoints.SaveBestCheckpointer(
      train_state,
      no_partitions_partitioner,
      self.tmp_dir,
      keep=2,
      metric_name_to_monitor='train/accuracy',
      metric_mode='max',
      keep_checkpoints_without_metrics=False)
  summary_writer = tensorboard.SummaryWriter(
      os.path.join(self.tmp_dir, 'train'))

  checkpointer.save(update_train_state_step(train_state, 42))
  summary_writer.scalar('accuracy', 0.9, 42)

  checkpointer.save(update_train_state_step(train_state, 43))
  summary_writer.scalar('accuracy', 0.7, 43)

  checkpointer.save(update_train_state_step(train_state, 44))
  summary_writer.scalar('accuracy', 0.8, 44)

  self.assertSequenceEqual(checkpointer.all_steps(), [42, 43, 44])

  # Pin checkpoint 43 by creating a `PINNED` file in its directory.
  ckpt_dir = self.checkpoints_dir.mkdir('checkpoint_43')
  pinned_file_43 = ckpt_dir.create_file('PINNED')

  # Verify that the pinned checkpoint 43 is always saved even though it does
  # not have the best metrics, and keep = 2.
  checkpointer.save(update_train_state_step(train_state, 45))
  self.assertSequenceEqual(checkpointer.all_steps(), [42, 43, 44, 45])

  checkpointer.save(update_train_state_step(train_state, 46))
  summary_writer.scalar('accuracy', 0.6, 46)

  # Unpin checkpoint 43 by deleting its `PINNED` file. It is a single file, so
  # `gfile.remove` is used (as elsewhere in this file) rather than `rmtree`.
  gfile.remove(pinned_file_43.full_path)

  # Checkpoint 43 should get deleted in the next update since it is not
  # pinned and does not have the best metrics.
  checkpointer.save(update_train_state_step(train_state, 47))
  self.assertSequenceEqual(checkpointer.all_steps(), [42, 44, 47])
@mock.patch('time.time', return_value=0)
def test_keep_pinned_save_best_checkpointer_missing_metrics(
    self, unused_mock_time):
  """Pinning combined with the default `keep_checkpoints_without_metrics`.

  Checkpoints without metrics are retained. Pinned checkpoints are retained
  regardless of their metric value.
  """
  partitioner = self.get_partitioner(0, 1, 1)
  state = self.train_state
  # keep_checkpoints_without_metrics is intentionally left at its default.
  ckpt = checkpoints.SaveBestCheckpointer(
      state,
      partitioner,
      self.tmp_dir,
      keep=1,
      metric_name_to_monitor='train/accuracy',
      metric_mode='max')

  # Metrics are written up front for steps 43-46 only; 41 and 42 have none.
  writer = tensorboard.SummaryWriter(os.path.join(self.tmp_dir, 'train'))
  for step, accuracy in ((43, 0.5), (44, 0.4), (45, 0.8), (46, 0.3)):
    writer.scalar('accuracy', accuracy, step)

  # 41 and 42 are kept even though they have no metrics.
  for step in (41, 42, 43):
    ckpt.save(update_train_state_step(state, step))
  self.assertSequenceEqual(ckpt.all_steps(), [41, 42, 43])

  # Pin checkpoints 41 and 43.
  for step in (41, 43):
    self.checkpoints_dir.mkdir(f'checkpoint_{step}').create_file('PINNED')

  # The pinned checkpoints 41 and 43 are retained.
  ckpt.save(update_train_state_step(state, 44))
  self.assertSequenceEqual(ckpt.all_steps(), [41, 42, 43, 44])

  # 44 is pruned on the next save. 43 survives in spite of its low accuracy
  # because it is pinned.
  ckpt.save(update_train_state_step(state, 45))
  self.assertSequenceEqual(ckpt.all_steps(), [41, 42, 43, 45])
@mock.patch('time.time', return_value=0)
def test_save_best_checkpointer_from_restart(self, unused_mock_time):
  """SaveBestCheckpointer honors metrics written before a restart/preemption."""
  partitioner = self.get_partitioner(0, 1, 1)
  state = self.train_state
  num_steps = 100

  # Phase 1: a plain Checkpointer with keep=None writes every step. Many
  # checkpoints are created to stress event collection, since some methods
  # use lossy/sampling collection.
  plain = checkpoints.Checkpointer(
      state, partitioner, self.tmp_dir, keep=None)
  for step in range(num_steps):
    plain.save(update_train_state_step(state, step))
  self.assertSequenceEqual(plain.all_steps(), list(range(num_steps)))

  # Accuracy for every step. Steps 42 and 53 use a 10x larger scale, so they
  # score highest.
  writer = tensorboard.SummaryWriter(os.path.join(self.tmp_dir, 'train'))
  for step in range(num_steps):
    scale = 0.01 if step in (42, 53) else 0.001
    writer.scalar('accuracy', step * scale, step)

  # Phase 2: "restart" with a SaveBestCheckpointer over the same directory.
  best = checkpoints.SaveBestCheckpointer(
      state,
      partitioner,
      self.tmp_dir,
      keep=2,
      metric_name_to_monitor='train/accuracy',
      metric_mode='max')
  # The pre-existing metrics must be read, so that only the two best old
  # checkpoints survive alongside the new one.
  best.save(update_train_state_step(state, 101))
  self.assertSequenceEqual(best.all_steps(), [42, 53, 101])
def test_save_best_checkpointer_force_keep_period(self):
  """Steps that are multiples of `force_keep_period` are never pruned."""
  partitioner = self.get_partitioner(0, 1, 1)
  state = self.train_state
  ckpt = checkpoints.SaveBestCheckpointer(
      state,
      partitioner,
      self.tmp_dir,
      keep=2,
      metric_name_to_monitor='train/accuracy',
      metric_mode='max',
      keep_checkpoints_without_metrics=False,
      force_keep_period=3)
  writer = tensorboard.SummaryWriter(os.path.join(self.tmp_dir, 'train'))

  # Save steps 0..9 with increasing accuracy, recording which steps survive
  # after each save.
  steps_after_save = {}
  for step in range(10):
    ckpt.save(update_train_state_step(state, step))
    writer.scalar('accuracy', step / 100, step)
    steps_after_save[step] = ckpt.all_steps()

  # Last step (8) is not a multiple of the keep period (3).
  self.assertSequenceEqual(steps_after_save[8], [0, 3, 5, 6, 7, 8])
  # Last step (9) is a multiple of the keep period (3).
  self.assertSequenceEqual(steps_after_save[9], [0, 3, 6, 7, 8, 9])
@mock.patch('time.time', return_value=0)
def test_save_best_checkpointer_missing_metrics(self, unused_mock_time):
  """Metric-less checkpoints are retained only while the flag allows it."""
  partitioner = self.get_partitioner(0, 1, 1)
  state = self.train_state
  ckpt = checkpoints.SaveBestCheckpointer(
      state,
      partitioner,
      self.tmp_dir,
      keep=1,
      metric_name_to_monitor='train/accuracy',
      metric_mode='max')

  # Metrics are written ahead of time for steps 43-45 only.
  writer = tensorboard.SummaryWriter(os.path.join(self.tmp_dir, 'train'))
  for step, accuracy in ((43, 0.6), (44, 0.5), (45, 0.4)):
    writer.scalar('accuracy', accuracy, step)

  # 41 and 42 have no metrics and are always retained under the default flag;
  # `keep` only applies to checkpoints that do have metrics. Step 44 is
  # pruned once 45 is saved.
  expectations = (
      (41, [41]),
      (42, [41, 42]),
      (43, [41, 42, 43]),
      (44, [41, 42, 43, 44]),
      (45, [41, 42, 43, 45]),
  )
  for step, expected in expectations:
    ckpt.save(update_train_state_step(state, step))
    self.assertSequenceEqual(ckpt.all_steps(), expected)

  # Turning keep_checkpoints_without_metrics off lets 41 and 42 be pruned.
  ckpt._keep_checkpoints_without_metrics = False
  ckpt.save(update_train_state_step(state, 46))
  self.assertSequenceEqual(ckpt.all_steps(), [43, 46])
def test_assignment_map(self):
  """Restoring through an assignment map fills new keys from saved ones."""
  self.validate_save(1, 1)

  # A train state whose target layout differs from the saved checkpoint.
  param_states = {
      'bias': np.int32(1),
      'kernel': np.array([1, 2], np.uint8),
  }
  target = {
      'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)),
      'layer1': {
          'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)),
          'kernel': np.arange(32, dtype=np.float32).reshape((2, 16)),
      },
      'layer2': {
          'bias': np.arange(32, dtype=np.float32).reshape((2, 16)),
          'kernel': np.arange(32, dtype=np.float32).reshape((2, 16)),
      },
  }
  optimizer = optimizers.Optimizer(
      optimizers.sgd(0.1),
      state=optimizers.OptimizerState(
          step=np.int32(42), param_states=param_states),
      target=target)
  self.train_state = FlaxOptimTrainState(optimizer)

  # layer2/bias is filled from the saved `target/kernel`. Every other
  # layerN/<name> is filled from the saved `target/<name>`. The specific rule
  # is listed first; the expected layer2/bias value shows it takes precedence
  # over the regex rule.
  assignment_map = [('target/layer2/bias', 'target/kernel'),
                    ('target/layer\\d/(.*)', 'target/\\1')]
  partitioner = self.get_partitioner(
      0, 1, 1, mesh_axes=jax.tree_map(lambda x: None, self.train_state))

  def restore_fn(c):
    return c.restore(
        step=42,
        state_transformation_fns=[
            functools.partial(
                state_utils.apply_assignment_map,
                assignment_map=assignment_map)
        ])

  restored = self.call_host_checkpointer(0, 1, partitioner, restore_fn,
                                         np.float32, None)

  self.assertEqual(restored.step, 42)
  self.assertEqual(restored._optimizer.optimizer_def,
                   self.train_state._optimizer.optimizer_def)
  jax.tree_multimap(np.testing.assert_array_equal, restored.param_states,
                    self.train_state.param_states)
  jax.tree_multimap(np.testing.assert_array_equal, restored.params,
                    self.train_state.params)
def test_assignment_map_unused(self):
  """An assignment-map pattern that matches no parameter raises ValueError."""
  self.validate_save(1, 1)
  # Nothing in the train state matches `target/layerN/...`.
  transform = functools.partial(
      state_utils.apply_assignment_map,
      assignment_map=[('target/layer\\d/(.*)', 'target/\\1')])
  with self.assertRaisesWithLiteralMatch(
      ValueError,
      "Unused patterns in `assignment_map`: {'target/layer\\d/(.*)'}"):
    self.call_host_checkpointer(
        0, 1, self.get_partitioner(0, 1, 1),
        lambda c: c.restore(step=42, state_transformation_fns=[transform]),
        np.float32, None)
def test_assignment_map_noexists(self):
  """Mapping to a key absent from the checkpoint raises a descriptive error."""
  self.validate_save(1, 1)
  # Each target/<name> in the state is looked up as target/layer/<name> in
  # the checkpoint, and no such key exists there.
  transform = functools.partial(
      state_utils.apply_assignment_map,
      assignment_map=[('target/(.*)', 'target/layer/\\1')])
  expected_msg = (
      "Parameter 'target/layer/bias' does not exist in restore checkpoint. "
      "Must be one of: ['state/param_states/bias', "
      "'state/param_states/kernel', 'state/step', 'target/bias', "
      "'target/kernel']")
  with self.assertRaisesWithLiteralMatch(ValueError, expected_msg):
    self.call_host_checkpointer(
        0, 1, self.get_partitioner(0, 1, 1),
        lambda c: c.restore(step=42, state_transformation_fns=[transform]),
        np.float32, None)
def test_assignment_map_partial_restore(self):
  """Only kernels come from the checkpoint; everything else from fallback."""
  self.validate_save(1, 1)

  # A train state whose target layout differs from the saved checkpoint.
  optimizer = optimizers.Optimizer(
      optimizers.sgd(0.1),
      state=optimizers.OptimizerState(
          step=np.int32(42),
          param_states={
              'bias': np.int32(1),
              'kernel': np.array([1, 2], np.uint8),
          }),
      target={
          'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)),
          'layer1': {
              'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)),
              'kernel': np.arange(32, dtype=np.float32).reshape((2, 16)),
          },
          'layer2': {
              'bias': np.arange(32, dtype=np.float32).reshape((2, 16)),
              'kernel': np.arange(32, dtype=np.float32).reshape((2, 16)),
          },
      })
  self.train_state = FlaxOptimTrainState(optimizer)

  assignment_map = [
      # Restore only the per-layer kernels from the saved target kernel...
      (r'target/layer(\d+)/kernel', r'target/kernel'),
      # ...and skip all biases and the optimizer state (mapped to None).
      (r'target.*bias', None),
      (r'state.*', None),
  ]
  # "From scratch" values for everything skipped by the map above.
  fallback_state = {
      'target': {
          'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)),
          'layer1': {
              'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)),
          },
          'layer2': {
              'bias': np.arange(32, dtype=np.float32).reshape((2, 16)),
          },
      },
      'state': {
          'step': 1337,  # Differs from the saved step (42).
          'param_states': {
              'bias': 1,
              'kernel': np.array([1, 2], np.uint8),
          },
      },
  }
  partitioner = self.get_partitioner(
      0, 1, 1, mesh_axes=jax.tree_map(lambda x: None, self.train_state))

  def restore_fn(c):
    return c.restore(
        step=42,
        state_transformation_fns=[
            functools.partial(
                state_utils.apply_assignment_map,
                assignment_map=assignment_map)
        ],
        fallback_state=fallback_state)

  restored = self.call_host_checkpointer(0, 1, partitioner, restore_fn,
                                         np.float32, None)

  self.assertEqual(restored._optimizer.optimizer_def,
                   self.train_state._optimizer.optimizer_def)
  # The step comes from fallback_state, not from the checkpoint.
  self.assertEqual(restored.step, 1337)
  jax.tree_multimap(np.testing.assert_array_equal, restored.param_states,
                    self.train_state.param_states)
  jax.tree_multimap(np.testing.assert_array_equal, restored.params,
                    self.train_state.params)
def verify_restore_checkpoint_from_path(
    self,
    path,
    model,
    decoder_only=False,
    partitioner_class=partitioning.PjitPartitioner):
  """Restores `path` through TrainStateInitializer and returns the state.

  Args:
    path: checkpoint path, restored with a 'specific'-mode
      RestoreCheckpointConfig.
    model: object supplying `optimizer_def` and `get_initial_variables`.
    decoder_only: if True, `encoder_input_tokens` is omitted from the input
      shapes.
    partitioner_class: partitioner type, instantiated with num_partitions=1.

  Returns:
    The single train state yielded by `from_checkpoints`.
  """
  partitioner = partitioner_class(num_partitions=1)
  feature_names = ['decoder_input_tokens']
  if not decoder_only:
    feature_names.append('encoder_input_tokens')
  features = {name: tf.zeros([2, 8]) for name in feature_names}
  element_spec = tf.data.Dataset.from_tensors(features).element_spec
  initializer = utils.TrainStateInitializer(
      optimizer_def=model.optimizer_def,
      init_fn=model.get_initial_variables,
      input_shapes={name: spec.shape for name, spec in element_spec.items()},
      partitioner=partitioner)
  restore_config = utils.RestoreCheckpointConfig(mode='specific', path=path)
  restored = list(initializer.from_checkpoints([restore_config]))
  self.assertLen(restored, 1)
  return restored[0]
def test_checkpointer_in_threaded_env(self):
  """Tests use of asyncio in checkpointer works with non-main threads."""
  # Save and restore each run on a single worker thread rather than the main
  # thread. Using the executor as a context manager shuts it down (joining its
  # worker) when the block exits, so no thread outlives the test.
  with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    executor.submit(self.validate_save, 1, 1).result()
    executor.submit(self.validate_restore, 1, 1).result()
def test_find_checkpoint(self):
  """Covers the path/step combinations accepted by `find_checkpoint`."""
  step = 100
  step_dir = os.path.join(self.fake_checkpoints, f'checkpoint_{step}')
  step_file = os.path.join(step_dir, 'checkpoint')

  # `model_dir` without a step resolves to the last entry of self.steps.
  self.assertEqual(
      checkpoints.find_checkpoint(self.fake_checkpoints),
      os.path.join(self.fake_checkpoints, f'checkpoint_{self.steps[-1]}',
                   'checkpoint'))
  # `model_dir` with a step resolves to that step's checkpoint file.
  self.assertEqual(
      checkpoints.find_checkpoint(self.fake_checkpoints, step), step_file)
  # A checkpoint directory resolves to the file inside it.
  self.assertEqual(checkpoints.find_checkpoint(step_dir), step_file)
  # A checkpoint directory combined with a step is rejected.
  with self.assertRaises(ValueError):
    checkpoints.find_checkpoint(step_dir, 1000)
  # A checkpoint file path is returned unchanged...
  self.assertEqual(checkpoints.find_checkpoint(step_file), step_file)
  # ...even when a step is also given.
  self.assertEqual(checkpoints.find_checkpoint(step_file, 1000), step_file)
  # `model_dir` with step 1000 is rejected.
  with self.assertRaises(ValueError):
    checkpoints.find_checkpoint(self.fake_checkpoints, 1000)
  # A `checkpoint` path directly under `model_dir` is rejected.
  with self.assertRaises(ValueError):
    checkpoints.find_checkpoint(
        os.path.join(self.fake_checkpoints, 'checkpoint'))
def test_restore_tf_as_t5x(self):
  """Restoring a TF checkpoint through the native T5X path is rejected."""
  tf_ckpt_dir = os.path.join(TESTDATA, 'mtf_tiny_t5')
  partitioner = self.get_partitioner(0, 1, 1)
  expected_error = (
      'Attempting to restore a TensorFlow checkpoint as a native T5X '
      'checkpoint. Use `restore_from_tf_checkpoint` instead. Path: .*')
  with self.assertRaisesRegex(ValueError, expected_error):
    self.call_host_checkpointer(0, 1, partitioner,
                                lambda c: c.restore(path=tf_ckpt_dir),
                                np.float32, None)
def test_restore_from_invalid_path(self):
  """Paths that are not T5X checkpoints are rejected with ValueError."""
  for bad_path in (TESTDATA, os.path.join(TESTDATA, 'checkpoint')):
    with self.assertRaisesRegex(ValueError,
                                r'Path is not a valid T5X checkpoint: .*'):
      self.verify_restore_checkpoint_from_path(bad_path,
                                               test_utils.get_t5_test_model())
def test_save_lazy_optimizer(self):
  """A train state made of lazily-loaded parameters can be saved again."""
  # Populate the checkpoint directory; it is restored below at step 42.
  self.validate_save(1, 1)

  # Restore lazily, keeping parameters off device.
  partitioner = self.get_partitioner(0, 1, 1, params_on_devices=False)
  lazy_state = self.call_host_checkpointer(
      0, 1, partitioner,
      lambda c: c.restore(step=42, lazy_parameters=True),  # pylint: disable=g-long-lambda
      np.float32, None)

  # Bump the step so the re-save lands in a new checkpoint.
  next_step = lazy_state.step.get() + 1
  state_dict = lazy_state.state_dict()
  state_dict['state']['step'] = next_step
  lazy_state = lazy_state.restore_state(state_dict)

  # Save the train state that still consists of lazy parameters.
  self.call_host_checkpointer(
      0, 1, partitioner,
      lambda c: c.save(train_state=lazy_state, concurrent_gb=2), np.float32,
      None)

  # Reload what was just written and compare it with the materialized values.
  saved = checkpoints.load_t5x_checkpoint(self.tmp_dir, step=next_step)
  materialized = jax.tree_map(
      lambda x: x.get() if isinstance(x, LazyArray) else x, lazy_state)
  jax.tree_multimap(np.testing.assert_allclose, materialized.state_dict(),
                    saved)
def test_update_ts_from_gfile_to_gcs(self):
  """A `file` kvstore is rewritten to `gcs`, with the dummy bucket added."""

  def contents_with_kvstore(kvstore):
    # Input and expected output differ only in the tensorstore kvstore. The
    # np.ndarray leaf must pass through unchanged.
    return {
        'version': 3,
        'optimizer': {
            'target': {
                'unsharded_param': np.ones((5, 5), dtype=np.int32),
                'sharded_param': {
                    'driver': 'zarr',
                    'dtype': 'float32',
                    'kvstore': kvstore,
                    'metadata': {
                        'chunks': [768, 768],
                        'compressor': {
                            'id': 'gzip',
                            'level': 1
                        },
                        'shape': [768, 768]
                    }
                }
            }
        }
    }

  ckpt_contents = contents_with_kvstore({
      'driver': 'file',
      'path': 'target.sharded_param'
  })
  expected = contents_with_kvstore({
      'bucket': 't5x-dummy-bucket',
      'driver': 'gcs',
      'path': 'target.sharded_param'
  })
  actual = checkpoints._maybe_update_ts_from_file_to_gcs(ckpt_contents)
  jax.tree_multimap(np.testing.assert_array_equal, actual, expected)
def test_update_ts_from_gcs_to_file(self):
  """A `gcs` kvstore is rewritten to a `file` kvstore, and the bucket dropped."""

  def contents_with_kvstore(kvstore):
    # Input and expected output differ only in the tensorstore kvstore. The
    # np.ndarray leaf must pass through unchanged.
    return {
        'version': 3,
        'optimizer': {
            'target': {
                'unsharded_param': np.ones((5, 5), dtype=np.int32),
                'sharded_param': {
                    'driver': 'zarr',
                    'dtype': 'float32',
                    'kvstore': kvstore,
                    'metadata': {
                        'chunks': [768, 768],
                        'compressor': {
                            'id': 'gzip',
                            'level': 1
                        },
                        'shape': [768, 768]
                    }
                }
            }
        }
    }

  ckpt_contents = contents_with_kvstore({
      'bucket': 't5x-dummy-bucket',
      'driver': 'gcs',
      'path': 'target.sharded_param'
  })
  expected = contents_with_kvstore({
      'driver': 'file',
      'path': 'target.sharded_param'
  })
  actual = checkpoints._maybe_update_ts_from_gcs_to_file(ckpt_contents)
  jax.tree_multimap(np.testing.assert_array_equal, actual, expected)
def assert_update_ts_path_from_relative_to_absolute(self, ts_spec_dict,
                                                    expected, ckpt_dir):
  """Asserts `ts_spec_dict` becomes `expected` once anchored at `ckpt_dir`.

  Checked twice: first on a tensorstore-normalized copy of the spec
  (tensorstore>=0.1.14), then on the raw dict (tensorstore<0.1.14). The
  second check updates `ts_spec_dict` itself in place.
  """

  def normalize(spec):
    return ts.Spec(spec).to_json()

  # tensorstore>=0.1.14: specs arrive normalized.
  normalized = normalize(ts_spec_dict)
  checkpoints._update_ts_path_from_relative_to_absolute(ckpt_dir, normalized)
  jax.tree_multimap(np.testing.assert_array_equal, normalize(normalized),
                    normalize(expected))

  # tensorstore<0.1.14: the raw dict is updated in place.
  checkpoints._update_ts_path_from_relative_to_absolute(ckpt_dir, ts_spec_dict)
  jax.tree_multimap(np.testing.assert_array_equal, ts_spec_dict, expected)
def test_update_ts_path_from_relative_to_absolute_gfile(self):
  """A relative `file` kvstore path is prefixed with the local ckpt_dir."""
  ckpt_dir = '/dir1/dir2'
  param_name = 'target.encoder.layers_0.attention.query.kernel'

  def zarr_spec(kvstore_path):
    return {
        'driver': 'zarr',
        'dtype': 'float32',
        'kvstore': {
            'driver': 'file',
            'path': kvstore_path
        },
        'metadata': {
            'chunks': [768, 768],
            'compressor': {
                'id': 'gzip',
                'level': 1
            },
            'shape': [768, 768]
        }
    }

  # The relative path becomes absolute under ckpt_dir.
  self.assert_update_ts_path_from_relative_to_absolute(
      zarr_spec(param_name), zarr_spec(f'{ckpt_dir}/{param_name}'), ckpt_dir)
def test_update_ts_path_from_relative_to_absolute_gcs(self):
  """A relative top-level `path` on GCS is made absolute inside the bucket.

  The spec's bucket is replaced by the one named in `ckpt_dir`. The resulting
  path is absolute relative to that bucket, i.e. the "gs://<bucket>/" prefix
  of `ckpt_dir` is stripped from it.
  """
  ckpt_dir = 'gs://test-bucket/dir1/dir2'
  param_name = 'target.encoder.layers_0.attention.query.kernel'

  def zarr_spec(bucket, path):
    return {
        'driver': 'zarr',
        'dtype': 'float32',
        'kvstore': {
            'bucket': bucket,
            'driver': 'gcs'
        },
        'metadata': {
            'chunks': [768, 768],
            'compressor': {
                'id': 'gzip',
                'level': 1
            },
            'shape': [768, 768]
        },
        'path': path,
        'transform': {
            'input_exclusive_max': [[768], [768]],
            'input_inclusive_min': [0, 0]
        }
    }

  ts_spec_dict = zarr_spec('t5x-dummy-bucket', param_name)
  # The bucket changes to the one in ckpt_dir. The path becomes absolute, with
  # the "gs://bucket/" portion stripped.
  expected = zarr_spec('test-bucket', f'dir1/dir2/{param_name}')
  self.assert_update_ts_path_from_relative_to_absolute(
      ts_spec_dict, expected, ckpt_dir)
def test_restore_tf_checkpoint(self):
  """The tiny TF test checkpoint restores into a model with matching dims."""
  tf_ckpt = os.path.join(TESTDATA, 'mtf_tiny_t5/model.ckpt-0')
  matching_model = test_utils.get_t5_test_model(
      emb_dim=32, head_dim=64, num_heads=2, mlp_dim=64)
  self.verify_restore_checkpoint_from_path(tf_ckpt, matching_model)
def test_restore_tf_checkpoint_wrong_config(self):
  """Restoring the TF checkpoint into a mismatched model reports a shape error."""
  tf_ckpt = os.path.join(TESTDATA, 'mtf_tiny_t5/model.ckpt-0')
  # The default test model's dimensions do not match the checkpoint.
  with self.assertRaisesRegex(ValueError, r'Variable .* has shape .* != .*'):
    default_model = test_utils.get_t5_test_model()
    self.verify_restore_checkpoint_from_path(tf_ckpt, default_model)
def test_convert_tf_checkpoint(self):
  """Smoke test: `convert_from_tf_checkpoint` runs on the tiny TF checkpoint."""
  tf_ckpt = os.path.join(TESTDATA, 'mtf_tiny_t5/model.ckpt-0')
  # Only enough setup to build an optimizer whose config matches the ckpt.
  model = test_utils.get_t5_test_model(
      emb_dim=32, head_dim=64, num_heads=2, mlp_dim=64)
  partitioner = partitioning.PjitPartitioner(num_partitions=1)
  input_shapes = {
      'encoder_input_tokens': (2, 512),
      'decoder_input_tokens': (2, 114),
  }

  def init_train_state(rng):
    variables = model.get_initial_variables(rng=rng, input_shapes=input_shapes)
    return FlaxOptimTrainState.create(model.optimizer_def, variables)

  # jax.eval_shape yields an abstract (shape/dtype-only) train state.
  abstract_state = jax.eval_shape(init_train_state, jax.random.PRNGKey(0))
  converter = checkpoints.Checkpointer(abstract_state, partitioner,
                                       self.tmp_dir)
  converter.convert_from_tf_checkpoint(tf_ckpt)
def test_load_matched(self):
  """`load_t5x_checkpoint` agrees with the state restored via the initializer."""
  ckpt_path = os.path.join(TESTDATA, 'test_t5_tiny.checkpoint_0')
  restored = self.verify_restore_checkpoint_from_path(
      ckpt_path, test_utils.get_t5_test_model())
  restored_dict = restored._optimizer.state_dict()
  loaded_dict = checkpoints.load_t5x_checkpoint(ckpt_path)
  jax.tree_multimap(np.testing.assert_array_equal, restored_dict, loaded_dict)
# Run the absltest main (which discovers and runs the tests above) when this
# file is executed directly.
if __name__ == '__main__':
  absltest.main()