youtube-music-transcribe / t5x /checkpoints.py
juancopi81's picture
Add t5x and mt3 models
b100e1c
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for reading and writing sharded checkpoints.
The checkpointing utilities here can be used in two ways. The first is to use
the `Checkpointer` class. This requires having an optimizer and various
partitioning utilities setup, but allows for reading and writing of partitioned
parameters. It also allows different hosts to read different parameter
partitions in a multi-host setup, which results in much faster reads. This is
normally used during training where you have already created an optimizer based
on a config.
The second way is to use the `load_t5x_checkpoint` function. This doesn't
require an optimizer to get given up front so it is useful for things like
debugging and analysis of learned weights. However, this means that we cannot do
partitioned reads so loading will be slower than that `Checkpointer` class.
"""
import asyncio
import dataclasses
import functools
import os
import re
import subprocess
import time
from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple
from absl import logging
from flax import serialization
from flax import traverse_util
import jax
import jax.config
from jax.experimental import global_device_array as gda_lib
from jax.experimental import multihost_utils
from jax.experimental.gda_serialization import serialization as gda_serialization
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint
from t5x import checkpoint_importer
from t5x import checkpoint_utils
from t5x import optimizers
from t5x import partitioning
from t5x import state_utils
from t5x import train_state as train_state_lib
import tensorflow as tf
from tensorflow.io import gfile
import tensorstore as ts
import typing_extensions
from tensorboard.backend.event_processing import directory_watcher
from tensorboard.backend.event_processing import event_file_loader
from tensorboard.backend.event_processing import io_wrapper
PartitionSpec = partitioning.PartitionSpec
PyTreeDef = type(jax.tree_structure(None))
LazyArray = checkpoint_importer.LazyArray
LazyAwaitableArray = checkpoint_importer.LazyAwaitableArray
LazyThreadPoolArray = checkpoint_importer.LazyThreadPoolArray
# Version 3 is used since 2021-06-10, compared to version 2 the only change is
# that `bfloat16` arrays are written in Tensorstore using its native `bfloat16`
# support instead of casting them to `uint16`.
VERSION = 3
# Desired chunk size is 64MiB.
# This is large enough to keep CNS happy but small enough to support a wide
# range of partitionings.
_DESIRED_CHUNK_SIZE_BYTES = 64 * 1024 * 1024
# TODO(levskaya, adarob): how should we handle stacked/fused variables??
_TRAIN_DS_PREFIX = 'train_ds'
def _choose_chunk_shape(write_shape: Sequence[int],
target_elements: int) -> List[int]:
"""Chooses a chunk shape that evenly divides write_shape.
The chunk shape is chosen such that the total number of elements is less than
or equal to `target_elements`, but is otherwise as large as possible.
This uses a greedy algorithm that attempts to split the largest dimensions
first.
Args:
write_shape: Write shape for which to choose a chunk shape.
target_elements: Desired number of elements in chosen chunk shape. Must be
>= 1.
Returns:
List of length `len(write_shape)` specifying the chosen chunk shape.
"""
assert target_elements >= 1
rank = len(write_shape)
# `dim_factors[i]` is the list of divisors of `write_shape[i]`
dim_factors = [
[i for i in range(1, size + 1) if size % i == 0] for size in write_shape
]
# The current chunk shape is:
# [dim_factors[i][-1] for i in range(rank)]
def get_total_elements():
"""Returns the number of elements in the current chunk shape."""
total_elements = 1
for i in range(rank):
total_elements *= dim_factors[i][-1]
return total_elements
# Reduce the current chunk shape until the desired number of elements is
# reached.
while get_total_elements() > target_elements:
# Greedily reduce the largest dimension. This is not guaranteed to bring us
# the closest to `target_elements`, but is simple to implement and should
# work well enough.
dim_to_reduce = -1
dim_to_reduce_size = 1
for i in range(rank):
size = dim_factors[i][-1]
if size > dim_to_reduce_size:
dim_to_reduce_size = size
dim_to_reduce = i
# Can only fail to choose `dim_to_reduce` if all dimensions have size of 1.
# But that cannot happen since `target_elements >= 1`.
assert dim_to_reduce_size > 1
dim_factors[dim_to_reduce].pop()
return [dim_factors[i][-1] for i in range(rank)]
@dataclasses.dataclass
class _ParameterInfo:
"""Information needed to read/write and slice a partitioned parameter."""
# The unique parameter name.
name: str
# The shape of the parameter.
shape: Tuple[int]
# The TensoreStore Spec containing the minimal information for read/write.
ts_spec: Optional[ts.Spec]
# The LocalChunkInfo for the part of the parameter local to this host.
local_chunk_info: Optional[partitioning.LocalChunkInfo]
# PartitionSpec mesh axes
axes: Optional[partitioning.PartitionSpec] = None
orbax.checkpoint.utils.register_ts_spec_for_serialization()
def _run_future_tree(future_tree):
"""Block until all futures are resolved on this host."""
future_leaves, treedef = jax.tree_flatten(future_tree)
# TODO(adarob): Use asyncio.run in py3.7+.
loop = asyncio.get_event_loop()
leaves = loop.run_until_complete(asyncio.gather(*future_leaves))
return jax.tree_unflatten(treedef, leaves)
def all_steps(checkpoints_dir: str) -> Sequence[int]:
"""Returns list of available step numbers in ascending order."""
glob_pattern = os.path.join(checkpoints_dir, 'checkpoint_*', 'checkpoint')
checkpoint_paths = gfile.glob(glob_pattern)
re_pattern = re.compile(r'.*/checkpoint_(\d+)/checkpoint$')
matches = [re_pattern.match(ckpt) for ckpt in checkpoint_paths]
return sorted(int(match.group(1)) for match in matches if match)
def all_dataset_checkpoint_steps(checkpoints_dir: str) -> Sequence[int]:
"""Returns available dataset checkpoint step numbers in ascending order."""
glob_pattern = os.path.join(checkpoints_dir, 'checkpoint_*',
f'{_TRAIN_DS_PREFIX}-*')
train_ds_paths = gfile.glob(glob_pattern)
re_pattern = re.compile(r'.*/checkpoint_(\d+)/.*$')
matches = [re_pattern.match(path) for path in train_ds_paths]
return sorted(set(int(match.group(1)) for match in matches if match))
def latest_step(checkpoints_dir: str) -> Optional[int]:
"""Returns latest step number or None if no checkpoints exist."""
steps = all_steps(checkpoints_dir)
if not steps:
return None
return steps[-1]
def _get_local_data(x):
if isinstance(x, gda_lib.GlobalDeviceArray):
return x.local_data(0)
else:
return x
def get_checkpoint_dir(checkpoints_dir: str, step: int) -> str:
"""Returns path to a checkpoint dir given a parent directory and step."""
return os.path.join(checkpoints_dir, f'checkpoint_{step}')
def _cast(target: PyTreeDef, dtype: jnp.dtype):
"""Cast arrays in target to dtype."""
def maybe_cast(x):
if isinstance(x, (int, str)):
# Ignore common non-array types that shouldn't be cast.
return x
elif x.dtype == dtype:
return x
elif isinstance(x, jax.ShapeDtypeStruct):
return jax.ShapeDtypeStruct(x.shape, dtype)
elif isinstance(x, gda_lib.GlobalDeviceArray):
raise ValueError('GDA cast not supported.')
else:
return x.astype(dtype)
return jax.tree_map(maybe_cast, target)
def _update_ts_path_from_relative_to_absolute(
ckpt_dir: str, ts_spec_dict: MutableMapping[str, Any]):
"""Update (in-place) the path and gcs bucket (if applicable) in a TS Spec."""
# Handle `gs://` paths.
m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_dir, re.DOTALL)
if m is not None:
if ts_spec_dict['kvstore']['driver'] != 'gcs':
raise ValueError(f'Incorrect TensorStore Spec. '
f'Expects kvstore driver to be "gcs" for {ckpt_dir}. '
f'Got {ts_spec_dict}')
bucket = m.group(1)
ckpt_dir = m.group(2)
ts_spec_dict['kvstore']['bucket'] = bucket
# Update the path with `ckpt_dir`
if 'path' in ts_spec_dict['kvstore']:
# tensorstore>=0.1.14 format
ts_spec_dict['kvstore']['path'] = os.path.join(
ckpt_dir, ts_spec_dict['kvstore']['path'])
elif 'path' in ts_spec_dict:
# tensorstore<0.1.14 format
ts_spec_dict['path'] = os.path.join(ckpt_dir, ts_spec_dict['path'])
else:
raise ValueError(
'Incorrect TensorStore Spec. Expects "path" to be a key of spec or '
f'`spec["kvstore"]`. Got {ts_spec_dict}')
def _maybe_update_ts_from_file_to_gcs(ckpt_contents):
"""Updates the TensorStore driver from gfile to gcs."""
def _gfile_to_gcs_driver(arr_or_ts_spec_dict):
"""Converts the ts.Spec dict using gfile driver to gcs driver."""
if not isinstance(arr_or_ts_spec_dict, dict):
return arr_or_ts_spec_dict
if arr_or_ts_spec_dict['kvstore']['driver'] in ('file', 'gfile'):
ts_spec_dict = arr_or_ts_spec_dict
path = ts_spec_dict['kvstore'].pop('path')
# This will be updated to the actual bucket in `_read_ts`.
ts_spec_dict['kvstore'] = {
'bucket': 't5x-dummy-bucket',
'driver': 'gcs',
'path': path
}
else:
if arr_or_ts_spec_dict['kvstore']['driver'] != 'gcs':
raise ValueError('Unsupported TensoreStore driver. Got '
f'{arr_or_ts_spec_dict["kvstore"]["driver"]}.')
ts_spec_dict = arr_or_ts_spec_dict
return ts_spec_dict
def _is_leaf(value):
return not isinstance(
value, dict) or set(value.keys()) >= {'driver', 'kvstore', 'metadata'}
return jax.tree_map(_gfile_to_gcs_driver, ckpt_contents, is_leaf=_is_leaf)
def _maybe_update_ts_from_gcs_to_file(ckpt_contents):
"""Updates the TensorStore driver to gfile or file if different."""
# if saved in gcs, change to file
def _gcs_to_file_driver(arr_or_ts_spec_dict):
if not isinstance(arr_or_ts_spec_dict, dict):
return arr_or_ts_spec_dict
if arr_or_ts_spec_dict['kvstore']['driver'] == 'gcs':
ts_spec_dict = arr_or_ts_spec_dict
path = ts_spec_dict['kvstore'].pop('path')
driver = 'file'
ts_spec_dict['kvstore'] = {'path': path, 'driver': driver}
elif arr_or_ts_spec_dict['kvstore']['driver'] == 'gfile':
ts_spec_dict = arr_or_ts_spec_dict
driver = 'file'
ts_spec_dict['kvstore']['driver'] = driver
elif arr_or_ts_spec_dict['kvstore']['driver'] == 'file':
ts_spec_dict = arr_or_ts_spec_dict
else:
raise ValueError('Unsupported TensoreStore driver. Got '
f'{arr_or_ts_spec_dict["kvstore"]["driver"]}.')
return ts_spec_dict
def _is_leaf(value):
return not isinstance(
value, dict) or set(value.keys()) >= {'driver', 'kvstore', 'metadata'}
return jax.tree_map(_gcs_to_file_driver, ckpt_contents, is_leaf=_is_leaf)
class _BytesConditionVariable(object):
"""Wraps a condition variable to control concurrency based on bytes."""
def __init__(self, num_bytes):
self._max_bytes = num_bytes
self._num_bytes = num_bytes
self._cv = asyncio.Condition(lock=asyncio.Lock())
async def wait_for_bytes(self, n_bytes):
async with self._cv:
await self._cv.wait_for(lambda: self._num_bytes > n_bytes)
self._num_bytes -= n_bytes
assert self._num_bytes >= 0
async def return_bytes(self, n_bytes):
async with self._cv:
self._num_bytes += n_bytes
assert self._num_bytes <= self._max_bytes
self._cv.notify_all()
class SaveStateTransformationFn(typing_extensions.Protocol):
def __call__(self, state_dict: PyTreeDef,
parameter_infos: PyTreeDef) -> Tuple[PyTreeDef, PyTreeDef]:
"""Transforms the state and param info, e.g., by remapping parameters.
Args:
state_dict: State in the current model.
parameter_infos: PyTree containing `_ParameterInfo` objects.
Returns:
A tuple whose first element is the result of transforming `state_dict` and
whose second element is the result of transforming `parameter_infos`.
"""
class RestoreStateTransformationFn(typing_extensions.Protocol):
def __call__(self,
state_dict: PyTreeDef,
target_state_dict: PyTreeDef,
*,
is_resuming: bool = False) -> PyTreeDef:
"""Transforms the given checkpoint state, e.g., by remapping parameters.
Args:
state_dict: State to transform, which could be from a previous version of
the model.
target_state_dict: State in the current model.
is_resuming: `True` iff this restore call is due to a job resuming after
being temporarily stopped due to, for example, a preemption. This is
useful when there is restore logic that should run when restoring from
some pre-existing checkpoint, but that should not run again when
resuming from a newly-written checkpoint.
Returns:
The result of transforming the `state_dict`.
"""
class Checkpointer(object):
"""Handles saving and restoring potentially-sharded T5X checkpoints.
Checkpoints are stored using a combination of msgpack (via flax.serialization)
and TensorStore.
Parameters (and other objects) that are not partitioned are written to the
msgpack binary directly (by host 0). Partitioned parameters are each written
to their own TensorStore, with each host writing their portion to the same
TensorStore in parallel. If a partition is written on multiple hosts, the
partition is further sharded across these replicas to avoid additional
overhead. In place of the paramater, a `tensorstore.Spec` is written to the
msgpack (by host 0) as a reference to be used during restore. Note that the
path of the array being written is relative. This makes the checkpoints
portable. In other words, even if the checkpoint files are moved to a new
directory, they can still be loaded. Because the path is relative, the
checkpoint directory information has to be dynamically provided. This is done
by `_update_ts_path_from_relative_to_absolute`.
For TensorStore driver using Google Cloud Storage (GCS) Key-Value Storage
Layer, the GCS bucket information is necessary. When a checkpoint is written
using the gcs driver, we don't want to hardcode the bucket information in the
resulting file in order to maintain the portability. Therefore, we use a dummy
bucket name of "t5x-dummy-bucket". When reading or writing the checkpoint, the
bucket information is parsed from the checkpoint directory and the bucket
information is dynamically updated.
Attributes:
checkpoints_dir: a path to a directory to save checkpoints in and restore
them from.
keep: an optional maximum number of checkpoints to keep. If more than this
number of checkpoints exist after a save, the oldest ones will be
automatically deleted to save space.
restore_dtype: optional dtype to cast targets to after restoring.
save_dtype: dtype to cast targets to before saving.
keep_dataset_checkpoints: an optional maximum number of data iterators to
keep. If more than this number of data iterators exist after a save, the
oldest ones will be automatically deleted to save space.
"""
def __init__(self,
train_state: train_state_lib.TrainState,
partitioner: partitioning.BasePartitioner,
checkpoints_dir: str,
dataset_iterator: Optional[tf.data.Iterator] = None,
*,
keep: Optional[int] = None,
save_dtype: jnp.dtype = np.float32,
restore_dtype: Optional[jnp.dtype] = None,
use_gda: Optional[bool] = False,
keep_dataset_checkpoints: Optional[int] = None):
"""Checkpointer constructor.
Args:
train_state: A train state to be used to determine the structure of the
parameter tree, and the *full* (non-partitioned) parameter shapes and
dtypes. Saved and restored train states must match this structure.
partitioner: the partitioner to use for determining the local chunks
mapping or to perform params partitioning on restore.
checkpoints_dir: a path to a directory to save checkpoints in and restore
them from.
dataset_iterator: an optional iterator to save/restore.
keep: an optional maximum number of checkpoints to keep. If more than this
number of checkpoints exist after a save, the oldest ones will be
automatically deleted to save space.
save_dtype: dtype to cast targets to before saving.
restore_dtype: optional dtype to cast targets to after restoring. If None,
no parameter casting is performed.
use_gda: if True, enabled gda_lib.GlobalDeviceArray. Note: this is
currently an experimental feature under development.
keep_dataset_checkpoints: an optional maximum number of data iterators to
keep. If more than this number of data iterators exist after a save, the
oldest ones will be automatically deleted to save space.
"""
self._train_state = train_state
self._partitioner = partitioner
self.checkpoints_dir = checkpoints_dir
self.keep = keep
self.keep_dataset_checkpoints = keep_dataset_checkpoints
# Immutable due to use in `_get_parameter_infos`
self._save_dtype = save_dtype
self.restore_dtype = restore_dtype
self._dataset_ckpt = (
tf.train.Checkpoint(ds=dataset_iterator) if dataset_iterator else None)
self._use_gda = use_gda
if self._use_gda:
logging.info('Checkpointing using GDA format is enabled.')
data_layout = partitioner.get_data_layout()
self._dataset_ckpt_name = (
f'{_TRAIN_DS_PREFIX}-'
f'{data_layout.shard_id:03}-of-{data_layout.num_shards:03}')
self._should_write_dataset_ckpt = (
dataset_iterator and data_layout.is_first_host_in_replica_set)
self._parameter_infos = self._get_parameter_infos()
asyncio.set_event_loop(asyncio.new_event_loop())
def _get_state_dict_for_save(self,
state_dict: Dict[str, Any],
lazy_load: bool = True) -> Mapping[str, Any]:
"""Gets the optimizer state dict."""
def _lazy_load_device_array(arr):
if isinstance(arr, jax.xla.DeviceArray):
return LazyThreadPoolArray(arr.shape, arr.dtype, lambda: np.array(arr))
return arr
if lazy_load:
state_dict = jax.tree_map(_lazy_load_device_array, state_dict)
return state_dict
def _get_parameter_infos(self):
"""Generates the state dict of _ParameterInfos for the Optimizer.
We generate a state dict (matching the shape of the optimizer state dict)
that stores a _ParameterInfo for each parameter array.
The _ParameterInfo contains the TensorStore spec for the parameter array and
the LocalChunkInfo describing the slice of the array local to this host.
Returns:
The state dict of _ParameterInfo objects.
"""
def _get_param_info(name: str, arr: Any, axes: partitioning.PartitionSpec):
# If a node in your model is None it is probably a param_state that is not
# used because of a MultiOptimizer. We don't want to have any parameter
# info for it because it shouldn't be saved or restored.
if arr is None:
return None
# Pass-through empty dict leaves, which occur with optax EmptyState().
if isinstance(arr, dict) and not arr:
return {}
if axes is None:
return _ParameterInfo(
name=name,
shape=arr.shape,
ts_spec=None,
local_chunk_info=None,
axes=None)
if self._use_gda and isinstance(arr, gda_lib.GlobalDeviceArray):
local_chunk_info = None
metadata = gda_serialization._get_metadata(arr) # pylint: disable=protected-access
del metadata['dtype']
else:
local_chunk_info = self._partitioner.get_local_chunk_info(
arr.shape, axes)
write_shape = [
si if sl == slice(None) else sl.stop - sl.start
for si, sl in zip(arr.shape, local_chunk_info.slice)
]
# TODO(levskaya, adarob): how should we handle stacked/fused variables??
chunk_shape = _choose_chunk_shape(
write_shape,
target_elements=_DESIRED_CHUNK_SIZE_BYTES / arr.dtype.itemsize)
metadata = {
'compressor': {
'id': 'gzip'
},
'shape': arr.shape,
'chunks': np.array(chunk_shape),
}
if self.checkpoints_dir.startswith('gs://'):
spec = {
'driver': 'zarr',
'dtype': jnp.dtype(arr.dtype).name,
'kvstore': {
'driver': 'gcs',
# We always write with a dummy bucket and dynamically update the
# bucket information. This makes the checkpoint files portable
# and not bind to the bucket that it was originally written to.
'bucket': 't5x-dummy-bucket',
},
'path': name.replace('/', '.'),
'metadata': metadata,
}
else:
spec = {
'driver': 'zarr',
'dtype': jnp.dtype(arr.dtype).name,
'kvstore': {
'driver': 'file',
'path': name.replace('/', '.')
},
'metadata': metadata,
}
return _ParameterInfo(
name,
shape=arr.shape,
ts_spec=ts.Spec(spec),
local_chunk_info=local_chunk_info,
axes=axes)
# Create a tree of param names as the keys on the path to each leaf
# separated by "/".
param_names = traverse_util.unflatten_dict({
k: '/'.join(k) for k in traverse_util.flatten_dict(
self._train_state.state_dict(), keep_empty_nodes=True)
})
return jax.tree_map(
_get_param_info, param_names,
self._get_state_dict_for_save(self._train_state.state_dict()),
self._partitioner.get_mesh_axes(self._train_state).state_dict())
def _get_checkpoint_dir(self, step: int) -> str:
return get_checkpoint_dir(self.checkpoints_dir, step)
def all_steps(self) -> Sequence[int]:
"""Returns list of available step numbers in ascending order."""
return all_steps(self.checkpoints_dir)
def all_dataset_checkpoint_steps(self) -> Sequence[int]:
"""Returns list of available step numbers in ascending order."""
return all_dataset_checkpoint_steps(self.checkpoints_dir)
def latest_step(self) -> Optional[int]:
"""Returns latest step number or None if no checkpoints exist."""
return latest_step(self.checkpoints_dir)
def _remove_old_dataset_checkpoints(self):
"""Deletes old dataset checkpoints if there are more than allowed."""
if self.keep_dataset_checkpoints:
existing_steps = self.all_dataset_checkpoint_steps()
to_remove = len(existing_steps) - self.keep_dataset_checkpoints
if to_remove > 0:
for step in existing_steps[:to_remove]:
checkpoint_utils.remove_dataset_checkpoint(
self._get_checkpoint_dir(step), _TRAIN_DS_PREFIX)
def _remove_old_checkpoints(self):
"""Deletes oldest checkpoints if there are more than keep_checkpoints."""
if not self.keep:
return
existing_steps = self.all_steps()
to_remove = len(existing_steps) - self.keep
if to_remove <= 0:
return
for step in existing_steps[:to_remove]:
checkpoint_utils.remove_checkpoint_dir(self._get_checkpoint_dir(step))
def save(self,
train_state: train_state_lib.TrainState,
state_transformation_fns: Sequence[SaveStateTransformationFn] = (),
*,
concurrent_gb: int = 128):
"""Saves a checkpoint for the given train state.
Args:
train_state: the train state to save. May contain a combination of
LazyArray objects and arrays (e.g., np.ndarray, jax.DeviceArray)
state_transformation_fns: Transformations to apply, in order, to the state
before writing.
concurrent_gb: the approximate number of gigabytes of partitionable
parameters to process in parallel. Useful to preserve RAM.
"""
step = train_state.step
step = step.get() if isinstance(step, LazyArray) else step
step = _get_local_data(step)
# Integer, to avoid side effects in the checkpoint path.
step = int(step)
# Share a timestamp across devices.
timestamp = multihost_utils.broadcast_one_to_all(np.int32(time.time()))
final_dir = os.path.join(self.checkpoints_dir, f'checkpoint_{step}')
tmp_dir = final_dir + f'.tmp-{timestamp}'
if gfile.exists(final_dir):
logging.info(
'Skipping save checkpoint for step %d (directory %s already exists)',
step, final_dir)
return
logging.info('Saving checkpoint for step %d to %s', step, tmp_dir)
if jax.process_index() == 0:
gfile.makedirs(tmp_dir)
# Block all hosts until directory is ready.
multihost_utils.sync_global_devices(f'checkpointer:make_dir:{tmp_dir}')
written_state_dict = self._write_state_to_tensorstore(
tmp_dir, train_state, concurrent_gb, state_transformation_fns)
if self._should_write_dataset_ckpt:
logging.info("Writing dataset iterator state to '%s'.",
self._dataset_ckpt_name)
try:
self._dataset_ckpt.write(os.path.join(tmp_dir, self._dataset_ckpt_name))
except tf.errors.FailedPreconditionError as e:
logging.error(
'Input pipeline must be stateless in order to checkpoint. Cache '
'stateful steps offline or disable iterator checkpointing.')
raise e
# Block until complete on all hosts.
multihost_utils.sync_global_devices(
f'checkpointer:tensorstore_write_complete:{tmp_dir}')
if jax.process_index() == 0:
written_state_dict = jax.tree_map(_get_local_data, written_state_dict)
# Write msgpack file in host 0 only
msgpack_bytes = serialization.to_bytes({
'version': VERSION,
'optimizer': written_state_dict
})
with gfile.GFile(os.path.join(tmp_dir, 'checkpoint'), 'wb') as fp:
fp.write(msgpack_bytes)
# Finalize checkpoint directory.
if final_dir.startswith('gs://'):
subprocess.run(['gsutil', '-m', 'mv', tmp_dir, final_dir],
stdout=subprocess.DEVNULL,
check=True)
else:
gfile.rename(tmp_dir, final_dir)
logging.info('Saved checkpoint for step %d to %s', step, final_dir)
# Remove old checkpoints, if necessary.
self._remove_old_checkpoints()
self._remove_old_dataset_checkpoints()
# Block until complete on all hosts.
multihost_utils.sync_global_devices(
f'checkpointer:write_complete:{final_dir}')
def _write_state_to_tensorstore(
self,
ckpt_dir: str,
train_state: train_state_lib.TrainState,
concurrent_gb: int,
state_transformation_fns: Sequence[SaveStateTransformationFn],
) -> Mapping[str, Any]:
"""Writes extracted state from train state to Tensorstore."""
concurrent_bytes = concurrent_gb * 10**9
bytes_cv = _BytesConditionVariable(concurrent_bytes)
async def _write_array(maybe_arr: Any,
param_info: Optional[_ParameterInfo],
cast: bool = False):
"""Maybe write to TensorStore, returning object to write to msgpack.
Args:
maybe_arr: array or LazyArray to be written
param_info: ParameterInfo object. If None (or if param_info.ts_spec is
None), the array will be immediately returned without writing to
tensorstore. This is because array is None or is not partitioned, and
should be written separately.
cast: if True, performs cast operation using self._save_dtype.
Returns:
Tensorstore spec corresponding to the written array.
"""
if param_info is None or param_info.ts_spec is None:
# Write to the msgpack file on host 0.
if isinstance(maybe_arr, LazyArray):
return await maybe_arr.get_async()
return maybe_arr
# Only write each chunk of a parameter from one host
if self._use_gda or param_info.local_chunk_info.replica_id == 0:
arr = maybe_arr
# Wait until memory is available.
if isinstance(arr, gda_lib.GlobalDeviceArray):
n_bytes = sum([
shard.data.nbytes
for shard in arr.local_shards
if shard.replica_id == 0
])
else:
n_bytes = arr.nbytes
if n_bytes > concurrent_bytes:
logging.warning(
'Temporarily increasing the concurrency limits from %d bytes to '
'%d bytes to fit %s.', concurrent_bytes, n_bytes, param_info.name)
n_bytes = concurrent_bytes
await bytes_cv.wait_for_bytes(n_bytes)
if isinstance(maybe_arr, LazyArray):
arr = await arr.get_async()
elif not isinstance(arr, np.ndarray) and not isinstance(
arr, gda_lib.GlobalDeviceArray):
# Cast jax.DeviceArray to np.ndarray.
arr = np.array(maybe_arr, dtype=maybe_arr.dtype)
tmp_ts_spec_dict = param_info.ts_spec.to_json()
if cast:
# Set desired destination dtype.
tmp_ts_spec_dict['dtype'] = jnp.dtype(self._save_dtype).name
param_info.ts_spec = ts.Spec(tmp_ts_spec_dict)
# Path and gcs bucket (if applicable) information is updated in-place.
_update_ts_path_from_relative_to_absolute(ckpt_dir, tmp_ts_spec_dict)
if cast:
# Set up casting spec.
tmp_ts_spec_dict = {
'base': tmp_ts_spec_dict,
'driver': 'cast',
'dtype': jnp.dtype(arr.dtype).name, # dtype before cast
}
if self._use_gda:
await gda_serialization.async_serialize(arr, tmp_ts_spec_dict)
else:
t = await ts.open(
tmp_ts_spec_dict,
create=True,
open=True,
context=ts.Context({'file_io_concurrency': {
'limit': 128
}}))
await t[param_info.local_chunk_info.slice].write(arr)
await bytes_cv.return_bytes(n_bytes)
# N.B. we return the original ts_spec (before
# `_update_ts_path_from_relative_to_absolute` was called). This is because
# we'd like to keep the path as relative, i.e., it doesn't hardcode the
# directory that the checkpoint was originally written. This makes the
# checkpoints portable.
return param_info.ts_spec
transformed_state_dict, transformed_parameter_infos = (
self._transform_state_and_infos(train_state.state_dict(),
self._parameter_infos,
state_transformation_fns))
state_dict_for_save = self._get_state_dict_for_save(transformed_state_dict)
def _cast_arr_if_not_partitioned(maybe_arr, param_info):
if param_info is None or param_info.ts_spec is None:
return _cast(maybe_arr, self._save_dtype)
return maybe_arr
state_dict_for_save['target'] = jax.tree_multimap(
_cast_arr_if_not_partitioned, state_dict_for_save['target'],
transformed_parameter_infos['target'])
future_written_state = {}
for k in state_dict_for_save.keys():
# ensure that only 'target' is cast
future_written_state[k] = jax.tree_multimap(
functools.partial(_write_array, cast=(k == 'target')),
state_dict_for_save[k], transformed_parameter_infos[k])
# Block until complete on this host.
written_state_dict = _run_future_tree(future_written_state)
# Block until complete on all hosts.
multihost_utils.sync_global_devices(
f'checkpointer:ts_write_complete:{ckpt_dir}')
return written_state_dict
def _transform_state_and_infos(
self,
state_dict: PyTreeDef,
parameter_infos: PyTreeDef,
state_transformation_fns: Sequence[SaveStateTransformationFn],
) -> Tuple[PyTreeDef, PyTreeDef]:
"""Applies transformations to the state dict and parameter infos PyTrees."""
for fn in state_transformation_fns:
state_dict, parameter_infos = fn(state_dict, parameter_infos)
return state_dict, parameter_infos
def restore(
self,
step: Optional[int] = None,
path: Optional[str] = None,
state_transformation_fns: Sequence[RestoreStateTransformationFn] = (),
fallback_state: Optional[Mapping[str, Any]] = None,
lazy_parameters: bool = False) -> train_state_lib.TrainState:
"""Restores the host-specific parameters in an Optimizer.
Either `step` or `path` can be specified, but not both. If neither are
specified, restores from the latest checkpoint in the checkpoints directory.
Args:
step: the optional step number to restore from.
path: an optional absolute path to a checkpoint file to restore from.
state_transformation_fns: Transformations to apply, in order, to the state
after reading.
fallback_state: a state dict of an optimizer to fall back to for loading
params that do not exist in the checkpoint (after applying all
`state_transformation_fns`), but do exist in `Checkpointer.optimizer`.
The union of `fallback_state` and state loaded from the checkpoint must
match `Checkpointer.optimizer`.
lazy_parameters: whether to load the parameters as LazyArrays to preserve
memory.
Returns:
The restored train state.
Raises:
ValueError if both `step` and `path` are specified.
ValueError if checkpoint at `path` or `step` does not exist.
ValueError if `step` and `path` are not specified and no checkpoint is
found in the checkpoints directory.
"""
if lazy_parameters and self._partitioner.params_on_devices:
raise ValueError('Lazy Parameters cannot be copied to devices, please '
'set partitioner.params_on_devices=False.')
if step is not None and path is not None:
raise ValueError('At most one of `step` or `path` may be provided.')
if path:
ckpt_path = path
else:
if step is None:
step = self.latest_step()
if not step:
raise ValueError(f'No checkpoints found in {self.checkpoints_dir}.')
ckpt_path = self._get_checkpoint_dir(step)
if gfile.isdir(ckpt_path):
ckpt_dir = ckpt_path
ckpt_path = os.path.join(ckpt_path, 'checkpoint')
else:
ckpt_dir = os.path.dirname(ckpt_path)
if not gfile.exists(ckpt_path) or gfile.isdir(ckpt_path):
raise ValueError(f'Path is not a valid T5X checkpoint: {ckpt_path}')
logging.info('Restoring from checkpoint: %s', ckpt_path)
with gfile.GFile(ckpt_path, 'rb') as fp:
# TODO(adarob): Use threaded reading as in flax.checkpoints.
raw_contents = fp.read()
if raw_contents.startswith(b'model_checkpoint_path'):
raise ValueError(
'Attempting to restore a TensorFlow checkpoint as a native T5X '
'checkpoint. Use `restore_from_tf_checkpoint` instead. Path: ' +
ckpt_path)
# `ckpt_contents['optimizer']` is a pytree with a realized np.array for
# leaves (params or states) written as msgpack and a ts.Spec (in a dict)
# for leaves written by TensorStore.
ckpt_contents = serialization.msgpack_restore(raw_contents)
# If reading a ckpt that was written with gfile driver but the current
# session uses the gcs driver, convert the ckpt's driver to gcs.
if ckpt_dir.startswith('gs://'):
ckpt_contents = _maybe_update_ts_from_file_to_gcs(ckpt_contents)
# If a ckpt was saved in gcs and is being loaded locally, then convert the
# driver to file or gfile. If the ckpt was not saved in gcs, do not change.
else:
ckpt_contents = _maybe_update_ts_from_gcs_to_file(ckpt_contents)
ckpt_state_dict = self._get_optimizer_state_dict(ckpt_contents,
state_transformation_fns)
# The state dict may contain TensorStore specs that need to be read.
dummy_spec = ts.Spec({'driver': 'zarr', 'kvstore': {'driver': 'memory'}})
# `dummy_written_state_dict` is a pytree with a `dummy_spec` for leaves
# (params or states) written as msgpack and a ts.Spec (in a dict) for leaves
# written by TensorStore.
dummy_written_state_dict = jax.tree_map(
lambda x: x.ts_spec or dummy_spec,
self._parameter_infos,
)
if fallback_state is None:
restore_parameter_infos = self._parameter_infos
else:
# If `fallback_state` was specified, restore only the subset
# of parameters matched by `self._get_optimizer_state_dict`. The
# rest will be provided by `fallback_state`.
dummy_written_state_dict = state_utils.intersect_state(
dummy_written_state_dict, ckpt_state_dict)
restore_parameter_infos = state_utils.intersect_state(
self._parameter_infos, ckpt_state_dict)
restore_parameter_infos_flat = state_utils.flatten_state_dict(
restore_parameter_infos)
for key in restore_parameter_infos_flat.keys():
logging.info('Restoring key from ckpt: %s', key)
# NB: `serialization.from_state_dict` doesn't check whether the shapes match
# at the leaf level. Non-partitioned leaves (e.g., optimizer states) can
# load arrays with inconsistent shapes.
# `written_state_dict` is a pytree with a realized np.array for leaves
# (params or states) written as msgpack and a `ts.Spec` for leaves written
# by TensorStore.
written_state_dict = serialization.from_state_dict(dummy_written_state_dict,
ckpt_state_dict)
state_dict = self._read_state_from_tensorstore(
ckpt_path,
written_state_dict,
restore_parameter_infos=restore_parameter_infos,
lazy_parameters=lazy_parameters)
# If `fallback_state` was specified, then fill the missing parameters.
if fallback_state is not None:
state_dict = state_utils.merge_state(state_dict, fallback_state)
for key in state_utils.flatten_state_dict(state_dict).keys():
if key not in restore_parameter_infos_flat:
logging.info('Not restoring key from ckpt: %s', key)
if self._dataset_ckpt:
logging.info("Restoring dataset iterator from '%s'.",
self._dataset_ckpt_name)
self._dataset_ckpt.read(os.path.join(
ckpt_dir, self._dataset_ckpt_name)).assert_consumed()
return self._restore_train_state(state_dict)
def _restore_train_state(
self,
state_dict: optimizers.OptimizerStateType) -> train_state_lib.TrainState:
"""Restores a TrainState from an Optimizer state_dict."""
train_state = self._train_state.restore_state(state_dict)
if not self._use_gda and self._partitioner.params_on_devices:
logging.info('Moving params to devices.')
train_state_axes = self._partitioner.get_mesh_axes(train_state)
train_state = self._partitioner.move_params_to_devices(
train_state, train_state_axes)
return train_state
def _create_lazy_awaitable_array(
self, param_info: _ParameterInfo, maybe_ts_spec: Any, ckpt_path: str,
restore_dtype: Optional[jnp.dtype]) -> LazyAwaitableArray:
"""Creates LazyArray from tensorstore.
Does not materialize the array immediately.
Args:
param_info: Information about how to read the parameter, host based sliced
reads and the like.
maybe_ts_spec: The tensorstore spec to read the parameter or some other
object. If this is an array then we will do a host based sliced read on
it (provided the param_info says to). Anything else we just return.
ckpt_path: A base location to use when resolving the relative paths in the
tensorstore spec.
restore_dtype: type to restore as. None indicates that no cast is
requested.
Returns:
LazyArray object.
"""
mesh = None
axes = None
if self._use_gda:
mesh = self._partitioner.mesh
axes = param_info.axes
get_fn = functools.partial(
_read_ts,
param_info,
maybe_ts_spec,
ckpt_path=ckpt_path,
restore_dtype=restore_dtype,
mesh=mesh,
axes=axes)
return LazyAwaitableArray.from_tensor_store_spec_or_array(
maybe_ts_spec, get_fn, dtype=restore_dtype)
def _read_state_from_tensorstore(
self,
ckpt_path: str,
written_state_dict: Mapping[str, Any],
restore_parameter_infos: Optional[Mapping[str, Any]] = None,
lazy_parameters: bool = False,
) -> Mapping[str, Any]:
"""Sets up lazy reads from Tensorstore and returns them as a state_dict."""
if restore_parameter_infos is None:
restore_parameter_infos = self._parameter_infos
# Replace TensorStore Specs with the lazy array values.
state_dict = {}
for k in written_state_dict.keys():
# ensure that only 'target' is cast
restore_dtype = self.restore_dtype if k == 'target' else None
state_dict[k] = jax.tree_multimap(
functools.partial(
self._create_lazy_awaitable_array,
ckpt_path=ckpt_path,
restore_dtype=restore_dtype), restore_parameter_infos[k],
written_state_dict[k])
if not lazy_parameters:
future_state_dict = jax.tree_map(lambda x: x.get_async(), state_dict)
state_dict = _run_future_tree(future_state_dict)
if self.restore_dtype is not None:
state_dict['target'] = _cast(state_dict['target'], self.restore_dtype)
return state_dict
def restore_from_tf_checkpoint(
self,
path_or_dir: str,
strict: bool = True,
translator: Optional[checkpoint_importer.CheckpointTranslator] = None
) -> train_state_lib.TrainState:
"""Restore from a TensorFlow-based T5 checkpoint."""
full_state_dict = checkpoint_importer.restore_from_t5_checkpoint(
self._train_state.state_dict(),
path_or_dir,
lazy_parameters=False,
strict=strict,
translator=translator)
def _partition_parameter(maybe_arr: Any, param_info: _ParameterInfo):
if isinstance(maybe_arr, np.ndarray) and param_info:
arr = maybe_arr
if param_info.shape is not None and arr.shape != param_info.shape:
raise ValueError(
f'Shape of `{param_info.name}` in checkpoint {arr.shape} does '
f'not match expected {param_info.shape}.')
if param_info.local_chunk_info:
arr = arr[param_info.local_chunk_info.slice]
return arr
return maybe_arr
state_dict = jax.tree_multimap(_partition_parameter, full_state_dict,
self._parameter_infos)
if self.restore_dtype is not None:
state_dict['target'] = _cast(state_dict['target'], self.restore_dtype)
return self._restore_train_state(state_dict)
def convert_from_tf_checkpoint(
self,
path_or_dir: str,
*,
state_transformation_fns: Sequence[SaveStateTransformationFn] = (),
concurrent_gb: int = 16,
translator: Optional[checkpoint_importer.CheckpointTranslator] = None):
"""Convert from a TensorFlow-based T5 checkpoint."""
full_state_dict = checkpoint_importer.restore_from_t5_checkpoint(
self._train_state.state_dict(),
path_or_dir,
lazy_parameters=True,
translator=translator)
train_state = self._train_state.restore_state(full_state_dict)
self.save(
train_state,
state_transformation_fns=state_transformation_fns,
concurrent_gb=concurrent_gb)
def _get_optimizer_state_dict(
self, ckpt_contents: PyTreeDef,
state_transformation_fns: Sequence[RestoreStateTransformationFn]):
return _get_optimizer_state_dict(ckpt_contents,
self._train_state.state_dict(),
state_transformation_fns)
class CheckpointerConstructor(typing_extensions.Protocol):
"""A function that returns a checkpoints.Checkpointer.
This type annotation allows users to partially bind args to the constructors
of Checkpointer subclasses without triggering type errors.
"""
def __call__(self,
train_state: train_state_lib.TrainState,
partitioner: partitioning.BasePartitioner,
checkpoints_dir: str,
dataset_iterator: Optional[tf.data.Iterator] = None,
*,
keep: Optional[int] = None,
save_dtype: jnp.dtype = np.float32,
restore_dtype: Optional[jnp.dtype] = None,
use_gda: Optional[bool] = False,
keep_dataset_checkpoints: Optional[int] = None) -> Checkpointer:
"""Checkpointer constructor.
Args:
train_state: A train state to be used to determine the structure of the
parameter tree, and the *full* (non-partitioned) parameter shapes and
dtypes. Saved and restored train states must match this structure.
partitioner: the partitioner to use for determining the local chunks
mapping or to perform params partitioning on restore.
checkpoints_dir: a path to a directory to save checkpoints in and restore
them from.
dataset_iterator: an optional iterator to save/restore.
keep: an optional maximum number of checkpoints to keep. If more than this
number of checkpoints exist after a save, the oldest ones will be
automatically deleted to save space.
save_dtype: dtype to cast targets to before saving.
restore_dtype: optional dtype to cast targets to after restoring. If None,
no parameter casting is performed.
use_gda: if True, enabled gda_lib.GlobalDeviceArray. Note: this is
currently an experimental feature under development.
keep_dataset_checkpoints: an optional maximum number of data iterators to
keep. If more than this number of data iterators exist after a save, the
oldest ones will be automatically deleted to save space.
"""
pass
class SaveBestCheckpointer(Checkpointer):
"""A Checkpointer class that keeps checkpoints based on 'best' metrics.
This extends the standard Checkpointer to garbage collect checkpoints based on
metric values, instead of step recency. It uses Tensorboard summary files to
determine best values for a given user configured metric name. Events are read
and parsed using Tensorboard's event_processing packages.
The metric name must be of the form `{run_name}/{tag_name}`. For example,
'train/accuracy' or 'inference_eval/glue_cola_v002/eval/accuracy'.
A few important features of this checkpointer:
- Fallback behavior. It is not possible to verify whether metric names are
valid during initialization, since some metrics may get written out after
some time (e.g., during an evaluation). As such, when user provided metric
names are not found, this checkpointer can be configured for two fall back
strategies: (1) if `keep_checkpoints_without_metrics` is False, we use to
the "most recent checkpoint" strategy from the standard checkpointer, (2)
if `keep_checkpoints_without_metrics` is True, we keep all checkpoints until
metrics become available (potentially indefinitely if summary files have
been deleted or corrupted).
- The number of checkpoints to keep is always increased by 1. Since its
crucial to always keep the latest checkpoint (for recovery purposes) we
always store the latest checkpoint plus `keep` number of best checkpoints.
- It is assumed that Tensorboard summaries (event) files share a common root
directory with `checkpoint_dir`, which is the directory passed to the
the logdir crawler that searches for event files.
Attributes:
checkpoints_dir: a path to a directory to save checkpoints in and restore
them from.
keep: an optional maximum number of checkpoints to keep. If more than this
number of checkpoints exist after a save, the oldest ones will be
automatically deleted to save space.
restore_dtype: optional dtype to cast targets to after restoring.
save_dtype: dtype to cast targets to before saving.
metric_name_to_monitor: Name of metric to monitor. Must be in the format
{run_name}/{tag_name} (e.g., 'train/accuracy',
'inference_eval/glue_cola_v002/eval/accuracy').
metric_mode: Mode to use to compare metric values. One of 'max' or 'min'.
keep_checkpoints_without_metrics: Whether to always keep (or delete)
checkpoints for which a metric value has not been found.
force_keep_period: When removing checkpoints, skip those who step is
divisible by force_keep_period (step % force_keep_period == 0).
use_gda: Enables GDA (see Checkpointer).
keep_dataset_checkpoints: an optional maximum number of data iterators to
keep. If more than this number of data iterators exist after a save, the
oldest ones will be automatically deleted to save space.
"""
def __init__(self,
train_state: train_state_lib.TrainState,
partitioner: partitioning.BasePartitioner,
checkpoints_dir: str,
dataset_iterator: Optional[tf.data.Iterator] = None,
*,
keep: Optional[int] = None,
save_dtype: jnp.dtype = np.float32,
restore_dtype: Optional[jnp.dtype] = None,
metric_name_to_monitor: str = 'train/accuracy',
metric_mode: str = 'max',
keep_checkpoints_without_metrics: bool = True,
force_keep_period: Optional[int] = None,
use_gda: bool = False,
keep_dataset_checkpoints: Optional[int] = None):
super().__init__(
train_state,
partitioner,
checkpoints_dir,
dataset_iterator,
keep=keep,
save_dtype=save_dtype,
restore_dtype=restore_dtype,
use_gda=use_gda,
keep_dataset_checkpoints=keep_dataset_checkpoints)
if metric_mode not in ('max', 'min'):
raise ValueError('Unsupported `metric_mode`: %s' % metric_mode)
# Metric run and tag names are derived from metric_name_to_monitor and are
# filled in _try_fill_metric_run_and_tag_names().
self._metric_run: Optional[str] = None
self._metric_tag: Optional[str] = None
self._metric_name_to_monitor = metric_name_to_monitor
self._metric_mode = metric_mode
self._keep_checkpoints_without_metrics = keep_checkpoints_without_metrics
self._force_keep_period = force_keep_period
logging.info('Using SaveBestCheckpointer to keep %s best (%s) metric %s',
keep, metric_mode, metric_name_to_monitor)
def _populate_metrics_for_steps(self,
steps: Iterable[int]) -> Mapping[int, float]:
"""Iterate through summary event files and return metrics for `steps`."""
metrics_by_step = {}
for subdir in io_wrapper.GetLogdirSubdirectories(self.checkpoints_dir):
rpath = os.path.relpath(subdir, self.checkpoints_dir)
# Skip runs that do not match user-specified metric.
if ((not self._metric_run and not self._try_fill_metric_run_and_tag_names(
(rpath,))) or self._metric_run != rpath):
logging.info('Skipping events in %s', subdir)
continue
logging.info('Looking for events in %s', subdir)
loader = directory_watcher.DirectoryWatcher(
subdir, event_file_loader.EventFileLoader,
io_wrapper.IsTensorFlowEventsFile)
for event in loader.Load():
# Skip metric collection of events for unavailable checkpoints or for
# unmonitored tags.
if (event.step not in steps or not event.summary.value or
event.summary.value[0].tag != self._metric_tag):
continue
metric_value = tf.make_ndarray(event.summary.value[0].tensor)
metrics_by_step[event.step] = metric_value
return metrics_by_step
def _try_fill_metric_run_and_tag_names(self, run_keys: Iterable[str]) -> bool:
"""Extract metric run and tag names by matching one of the `run_keys`.
This function tries to greedily split user-provided metric_name_to_monitor
into {run} and {tag} components. It does so by trying to match all available
{run}/{tag} names in the provided run_keys. If successful, populates
self._metric_run and self._metric_tag.
Args:
run_keys: Set of run keys to test for.
Returns:
Whether metric name prefix matches one of the run keys, and, as a
side-effect, populates self._metric_run and self._metric_tag.
"""
metric_run, metric_tag = None, None
# Query existing events for different run and tags to match with user
# provided metric name.
m = self._metric_name_to_monitor.split('/')
possible_run_names = ['/'.join(m[:i]) for i in range(1, len(m))]
for key in run_keys:
for possible_run_name in possible_run_names:
if key == possible_run_name:
metric_run = possible_run_name
metric_tag = self._metric_name_to_monitor[len(metric_run) + 1:]
break
if metric_run and metric_tag:
self._metric_run, self._metric_tag = metric_run, metric_tag
return True
return False
def _filter_out_force_keep_period_steps(self, existing_steps):
"""Filter out steps that are divisible by keep_period excluding the last."""
if not existing_steps:
return existing_steps
# Don't filter out the last step.
last_step = existing_steps.pop()
existing_steps = [
s for s in existing_steps if s % self._force_keep_period != 0
]
return existing_steps + [last_step]
def _remove_old_checkpoints(self):
"""Deletes checkpoints if there are more than keep_checkpoints."""
if not self.keep:
return
existing_steps = self.all_steps()
if self._force_keep_period:
# Ignore checkpoints whose step is divisible by the keep period.
existing_steps = self._filter_out_force_keep_period_steps(existing_steps)
# Artificially add 1 to `keep` since we always keep the latest checkpoint.
if len(existing_steps) <= self.keep + 1:
return
# Synchronous fetch of new events for existing_steps.
metrics_by_step = self._populate_metrics_for_steps(existing_steps)
logging.info('SaveBestcheckpointer: collected metrics %s', metrics_by_step)
# Re-sort existing_steps by metric values while always keeping the latest
# checkpoint.
latest_checkpoint = existing_steps[-1]
existing_steps = existing_steps[:-1]
if self._keep_checkpoints_without_metrics:
existing_steps = list(
filter(lambda s: s in metrics_by_step, existing_steps))
to_remove = len(existing_steps) - self.keep
if to_remove <= 0:
return
# For any remaining steps without metrics, we assign a low/high value which
# will make them candidate for removal. If no metrics are found this sorting
# should preserve current order (oldest first).
not_found_value = float('-inf' if self._metric_mode == 'max' else 'inf')
existing_steps = sorted(
existing_steps,
key=lambda step: metrics_by_step.get(step, not_found_value),
reverse=(self._metric_mode != 'max'))
existing_steps.append(latest_checkpoint)
for step in existing_steps[:to_remove]:
checkpoint_utils.remove_checkpoint_dir(self._get_checkpoint_dir(step))
def _get_optimizer_state_dict(
ckpt_contents: PyTreeDef, optimizer_state: Mapping[str, Any],
state_transformation_fns: Sequence[RestoreStateTransformationFn]):
"""Extracts optimizer state dict contents and applies assignment map."""
version = ckpt_contents.get('version', 0)
if version == 0:
# This is a standard Flax checkpoint and may require remapping below.
ckpt_optimizer_state = ckpt_contents
else:
ckpt_optimizer_state = ckpt_contents['optimizer']
if version >= 2:
for fn in state_transformation_fns:
ckpt_optimizer_state = fn(ckpt_optimizer_state, optimizer_state)
return ckpt_optimizer_state
else:
raise ValueError('Checkpoint versions earlier than 2 are not supported. ' # pylint: disable=unreachable
f'Got version: {version}')
async def _read_ts(param_info: _ParameterInfo,
maybe_tspec: Any,
ckpt_path: str,
restore_dtype: Optional[jnp.dtype] = None,
mesh: Optional[gda_lib.Shape] = None,
axes: Optional[gda_lib.MeshAxes] = None):
"""Read from a tensorstore.
If both `mesh` and `axes` are provided, the method will attempt to restore the
array as a GlobalDeviceArray.
Note:
We use param_infos as the first argument because this function is only used
in `jax.tree_multimap` calls. In a tree multimap if the leaf of the first
tree is `None` then is is ignored, even if the second tree has a subtree
at that point. This means that when we are using something like a
MultiOptimizer we can set the parameter info for a variable to `None` and
we can skip processing it, even if the checkpoint has a subtree with things
like optimizer state variables in it.
Args:
param_info: Information about how to read the parameter, host based sliced
reads and the like.
maybe_tspec: The tensorstore spec to read the parameter or some other
object. If this is an array then we will do a host based sliced read on it
(provided the param_info says to). Anything else we just return.
ckpt_path: A base location to use when resolving the relative paths in the
tensorstore spec.
restore_dtype: type to restore as. None indicates that no cast is requested.
mesh: Mesh object for GDA restoration.
axes: MeshAxes object for GDA restoration.
Returns:
The array. Depending on the value `maybe_tspec` it might be read from
tensorstore, or it might be returned as is. Depending on the values in
param_info (specifically the `local_chunk_info`) it might be the full value
or a specific slice.
"""
# If saved as a numpy array, but a partitioned read is requested, return a
# slice of the array for that host. Otherwise, return the whole thing.
if isinstance(maybe_tspec, np.ndarray) and param_info:
if param_info.local_chunk_info:
arr = maybe_tspec
return arr[param_info.local_chunk_info.slice]
else:
return maybe_tspec
# If we have anything else that isn't a tensorstore spec just return it.
elif not isinstance(maybe_tspec, ts.Spec):
return maybe_tspec
tmp_ts_spec_dict = maybe_tspec.to_json()
# Remove non-required params so that we can open Tensorstore
# that was created with a different set of params.
del tmp_ts_spec_dict['metadata']['chunks']
del tmp_ts_spec_dict['metadata']['compressor']
# Convert the relative path in the spec to a path based on the checkpoint
# location. Path and gcs bucket (if applicable) information is updated
# in-place.
_update_ts_path_from_relative_to_absolute(
os.path.dirname(ckpt_path), tmp_ts_spec_dict)
if param_info.shape is not None:
ts_spec_arr_shape = tuple(tmp_ts_spec_dict['metadata']['shape'])
# Check that the shapes of the array on disk match the expected shape based
# on the optimizer that is being restored.
if ts_spec_arr_shape != param_info.shape:
raise ValueError(f'Shape of `{param_info.name}` in checkpoint '
f'{ts_spec_arr_shape} does not match expected '
f'{param_info.shape}.')
if ('dtype' in tmp_ts_spec_dict and tmp_ts_spec_dict['dtype']
== 'uint16') or ('dtype' in tmp_ts_spec_dict['metadata'] and
tmp_ts_spec_dict['metadata']['dtype'] == '<u2'):
raise ValueError(
f'Found unsupported uint16 type in Tensorstore spec: {tmp_ts_spec_dict}. '
'Please use t5x/google/scripts/convert_uint16_checkpoint.py '
'to update saved types to bfloat16.')
if restore_dtype is not None:
tmp_ts_spec_dict = {
'base': tmp_ts_spec_dict,
'driver': 'cast',
'dtype': jnp.dtype(restore_dtype).name
}
if mesh is None or axes is None:
# Read the array.
t = await ts.open(tmp_ts_spec_dict, open=True)
if param_info.local_chunk_info is not None:
# Just read the subsection we care about.
t = t[param_info.local_chunk_info.slice]
arr = await t.read()
else:
# if provided, read as GDA
arr = await gda_serialization.async_deserialize(mesh, axes,
tmp_ts_spec_dict)
return arr
def fake_param_info(maybe_tspec: Any) -> Optional[_ParameterInfo]:
"""Create _ParameterInfo that results in a full read."""
# tspec is only None for `param_states` where the associated variable
# is not updated by any optimizers. By setting the parameter info for
# this to None, we can later short circut processing these subtrees
# during loading.
if maybe_tspec is None:
return None
local_chunk_info = None
tspec = None
if isinstance(maybe_tspec, ts.Spec):
tspec = maybe_tspec
local_chunk_info = partitioning.LocalChunkInfo(
slice=(slice(None, None),), replica_id=0)
return _ParameterInfo(
name='', # We don't ever use the name.
shape=tuple(tspec.to_json()['metadata']['shape']) if tspec else None,
# We just believe the spec in the file.
ts_spec=tspec,
local_chunk_info=local_chunk_info,
axes=None)
def find_checkpoint(path: str, step: Optional[int] = None) -> str:
"""Find the checkpoint file based on paths and steps.
Args:
path: The location of the checkpoint. Can point to the `model_dir`, the
checkpoint dir with a step, or the actual checkpoint file.
step: The step to load. Only used if you are pointing to the `model_dir`
Raises:
ValueError if the checkpoint file can't be found.
Returns:
The path to the checkpoint file.
"""
# If you aren't pointing at the msgpack checkpoint file
if gfile.isdir(path):
# If you didn't specify a step
if step is None:
# Try to get the most recent step.
step = latest_step(path)
# If you found a step then you were pointing at model_dir, set the path to
# the msgpack file in the checkpoint dir.
if step:
path = get_checkpoint_dir(path, step)
# You gave a step, use it.
else:
path = get_checkpoint_dir(path, step)
# Whether you supplied a step, found a step, or were already pointing at the
# step, you are not pointing at a step directory, so now point to the
# msgpack file.
path = os.path.join(path, 'checkpoint')
# You weren't point to a dir so you were pointing at the msgpack file.
# Check that we found a checkpoint file.
if not gfile.exists(path) or gfile.isdir(path):
raise ValueError(f'Path is not a valid checkpoint: {path}')
return path
def load_t5x_checkpoint(
path: str,
step: Optional[int] = None,
state_transformation_fns: Sequence[RestoreStateTransformationFn] = (),
remap: bool = True,
restore_dtype: Optional[jnp.dtype] = None,
lazy_parameters: bool = False) -> PyTreeDef:
"""Load a T5X checkpoint without pre-defining the optimizer.
Note:
This only works for T5X checkpoints, not TF checkpoints.
Args:
path: The location of the checkpoint.
step: The checkpoint from which step should be loaded.
state_transformation_fns: Transformations to apply, in order, to the state
after reading.
remap: Whether to rename the checkpoint variables to the newest version.
restore_dtype: optional dtype to cast targets to after restoring. If None,
no parameter casting is performed.
lazy_parameters: whether to load the parameters as LazyArrays to preserve
memory.
Returns:
A nested dictionary of weights and parameter states from the checkpoint.
"""
path = find_checkpoint(path, step)
logging.info('Restoring from checkpoint: %s', path)
# The msgpack file will have all the info we need about the parameter layout.
with gfile.GFile(path, 'rb') as fp:
ckpt_contents = serialization.msgpack_restore(fp.read())
# If reading a ckpt that was written with gfile driver but the current
# session uses the gcs driver, convert the ckpt's driver to gcs.
if path.startswith('gs://'):
ckpt_contents = _maybe_update_ts_from_file_to_gcs(ckpt_contents)
# If a ckpt was saved in gcs and is being loaded locally, then convert the
# driver to file or gfile. If the ckpt was not saved in gcs, do not change.
else:
ckpt_contents = _maybe_update_ts_from_gcs_to_file(ckpt_contents)
# Remap that variable names to the most recent formatting.
if remap:
ckpt_optimizer_state = _get_optimizer_state_dict(ckpt_contents, {},
state_transformation_fns)
# If we aren't remapping names we at least need to index into the checkpoint
# file blob to make sure we are only dealing with the optimizer state.
else:
# Grab a subsection of the file depending on the version.
version = ckpt_contents.get('version', 0)
if version == 0:
ckpt_optimizer_state = ckpt_contents
else:
ckpt_optimizer_state = ckpt_contents['optimizer']
# Replace all dicts of tensorstore specs with actual `ts.Spec`s.
# When a checkpoint was trained using a MultiOptimizer, some of the parameter
# states may be set to `None` (when a parameter was untouched by any
# optimizer). We still needs references to these in our state so we keep
# empty nodes.
ckpt_optimizer_state_with_specs = (
state_utils.flatten_state_dict(
ckpt_optimizer_state, keep_empty_nodes=True))
ckpt_optimizer_state_with_specs = {
k: ts.Spec(v) if isinstance(v, dict) else v
for k, v in ckpt_optimizer_state_with_specs.items()
}
# Create fake parameter info that results in reading the whole variable.
param_infos = {
k: fake_param_info(v) for k, v in ckpt_optimizer_state_with_specs.items()
}
ckpt_optimizer_state_with_specs = traverse_util.unflatten_dict(
ckpt_optimizer_state_with_specs, sep='/')
param_infos = traverse_util.unflatten_dict(param_infos, sep='/')
def _create_lazy_awaitable_array(
param_info: _ParameterInfo, maybe_ts_spec: Any, ckpt_path: str,
restore_dtype: Optional[jnp.dtype]) -> LazyAwaitableArray:
get_fn = functools.partial(
_read_ts,
param_info,
maybe_ts_spec,
ckpt_path=ckpt_path,
restore_dtype=restore_dtype)
return LazyAwaitableArray.from_tensor_store_spec_or_array(
maybe_ts_spec, get_fn, dtype=restore_dtype)
state_dict = jax.tree_multimap(
functools.partial(
_create_lazy_awaitable_array,
ckpt_path=path,
restore_dtype=restore_dtype), param_infos,
ckpt_optimizer_state_with_specs)
if not lazy_parameters:
future_state_dict = jax.tree_map(lambda x: x.get_async(), state_dict)
state_dict = _run_future_tree(future_state_dict)
if restore_dtype is not None:
state_dict['target'] = _cast(state_dict['target'], restore_dtype)
return state_dict