# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Script to pretrain or finetune in JAX using a SeqIO pipeline."""
import functools
import math
import os
import time
from typing import Callable, Sequence, Mapping, Tuple, Type, Optional

# Set Linen to add profiling information when constructing Modules.
# Must be set before flax imports.
# pylint:disable=g-import-not-at-top
os.environ['FLAX_PROFILE'] = 'true'
# TODO(adarob): Re-enable once users are notified and tests are updated.
os.environ['FLAX_LAZY_RNG'] = 'no'
from absl import logging
from clu import metric_writers
import clu.data
import jax
from jax import random
from jax.experimental import multihost_utils
import jax.numpy as jnp
import numpy as np
import seqio
from t5x import models
from t5x import partitioning
from t5x import train_state as train_state_lib
from t5x import trainer as trainer_lib
from t5x import utils
import tensorflow as tf

# Automatically search for gin files relative to the T5X package.
_DEFAULT_GIN_SEARCH_PATHS = [
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
]

PyTreeDef = type(jax.tree_structure(None))
P = partitioning.PartitionSpec
# Special key used to distinguish train metrics.
TRAIN_METRIC_KEY = 'train'
# String keys that are acceptable from config.
_ACTION_KEYS = frozenset(trainer_lib.ActionMode.__members__.keys())


def run_actions(
    mode: trainer_lib.ActionMode, actions: trainer_lib.ActionMapType,
    train_state: train_state_lib.TrainState,
    metrics_by_task: Mapping[str, trainer_lib.MetricValueMapType]) -> bool:
  """Invokes all actions for the given mode on host 0, then broadcasts to all.

  Args:
    mode: The mode to run the actions. e.g., if mode is `train`, only actions
      configured to run with `train` mode will be invoked.
    actions: A mapping of actions that run after train, eval or infer_eval, to
      inspect the model and perform useful operations, e.g., early stopping.
    train_state: The current train_state of the trainer.
    metrics_by_task: A map of metrics keyed by task name.

  Returns:
    A bool indicating whether training should be halted.

  Raises:
    RuntimeError: When the metrics processed on host 0 are empty.
  """
  stop_training = False
  if jax.process_index() == 0:
    if not metrics_by_task:
      raise RuntimeError('Metric is unexpectedly empty on process 0')
    for action in actions.get(mode, []):
      stop_training |= action.run(train_state, metrics_by_task=metrics_by_task)
  # Broadcast result from host 0 to others.
  return bool(multihost_utils.broadcast_one_to_all(jnp.array(stop_training)))


def train(
    *,
    model: models.BaseTransformerModel,
    train_dataset_cfg: utils.DatasetConfig,
    train_eval_dataset_cfg: Optional[utils.DatasetConfig],
    infer_eval_dataset_cfg: Optional[utils.DatasetConfig],
    checkpoint_cfg: utils.CheckpointConfig,
    partitioner: partitioning.BasePartitioner,
    trainer_cls: Type[trainer_lib.BaseTrainer],
    model_dir: str,
    total_steps: int,
    eval_steps: int,
    eval_period: int,
    stats_period: Optional[int] = None,
    random_seed: Optional[int],
    use_hardware_rng: bool = False,
    summarize_config_fn: Callable[[str, metric_writers.MetricWriter, int],
                                  None],
    inference_evaluator_cls: Type[seqio.Evaluator] = seqio.Evaluator,
    get_dataset_fn: utils.GetDatasetCallable = utils.get_dataset,
    concurrent_metrics: bool = True,
    actions: Optional[Mapping[str, Sequence[trainer_lib.BaseAction]]] = None,
    train_eval_get_dataset_fn: Optional[utils.GetDatasetCallable] = None,
    run_eval_before_training: bool = False,
    use_gda: bool = False) -> Tuple[int, train_state_lib.TrainState]:
  """Train function.

  Args:
    model: The model object to use for training.
    train_dataset_cfg: Specification for the dataset to train with.
    train_eval_dataset_cfg: Specification for the dataset to evaluate with
      using the train metrics and no inference (e.g., uses teacher forcing).
      If None, train eval is disabled.
    infer_eval_dataset_cfg: Specification for the dataset to evaluate with
      using the inference metrics (e.g., uses sampled decoding). If None,
      inference eval is disabled.
    checkpoint_cfg: Specification for saving and restoring model parameters
      and dataset state to/from checkpoints.
    partitioner: Partitioner for model parameters and data across devices.
    trainer_cls: An implementation of BaseTrainer.
    model_dir: Path of directory to store checkpoints and metric summaries.
    total_steps: The step number to stop training after. The number of actual
      steps trained in this run will be this number minus the starting step
      from the checkpoint.
    eval_steps: The number of batches to process for each train-eval loop.
    eval_period: The number of train steps between each evaluation (both
      train-eval and infer-eval).
    stats_period: The number of train steps between writing scalar stats. If
      None, defaults to eval_period.
    random_seed: A random seed to use for dropout and initialization. If None,
      a fast, non-deterministic hardware-based RNG is used.
    use_hardware_rng: Whether to force using the RngBitGenerator based
      hardware rng, which takes seeds and acts similarly to software PRNG in
      that it should be seed-deterministic. The new RngBitGenerator custom
      PRNG system should be reproducible for a given sharding, but the numbers
      will change for different shardings of the same model.
    summarize_config_fn: A function that takes in the model directory, a
      SummaryWriter, and the step number, and writes a summary of the
      configuration.
    inference_evaluator_cls: seqio.Evaluator class to use for inference
      evaluation, potentially with bound configuration args.
    get_dataset_fn: The callable used to get the train and train-eval datasets
      based on the DatasetConfig and shard information.
    concurrent_metrics: If True, allow metrics computation and logging to
      overlap with training. Will likely result in additional TPU memory usage.
    actions: A mapping of actions that run after train, eval or infer_eval, to
      inspect the model and perform useful operations, e.g., early stopping.
      The key must have a 1:1 mapping to ActionMode enum. For EVAL actions to
      actually work, this requires `concurrent_metrics` to be turned off,
      since chaining futures and mutating states concurrently might be
      error-prone.
    train_eval_get_dataset_fn: Optional callable used to get the train-eval
      datasets based on the DatasetConfig and shard information. If missing,
      it defaults to `get_dataset_fn`.
    run_eval_before_training: If True, calculate training eval and inference
      eval metrics before training begins.
    use_gda: if True, uses GlobalDeviceArray. Experimental feature.

  Returns:
    The tuple of (last_step, last_train_state).
  """
  logging.info('Process ID: %d', jax.process_index())
  tf.io.gfile.makedirs(model_dir)

  jax.config.update('jax_parallel_functions_output_gda', use_gda)

  # Each "epoch" of the training loop should be the min of the eval period,
  # checkpoint period or the full training.
  # We compute here to ensure that the eval period and checkpoint period are
  # divisible by this number, otherwise we fail.
  eval_enabled = (train_eval_dataset_cfg or infer_eval_dataset_cfg)
  eval_period = eval_period if eval_enabled else 0
  checkpoint_period = checkpoint_cfg.save.period if checkpoint_cfg.save else 0
  if eval_period or checkpoint_period:
    steps_per_epoch = min(eval_period or np.inf, checkpoint_period or np.inf)
  else:
    steps_per_epoch = total_steps
  stats_period = stats_period or steps_per_epoch
  if (eval_period and eval_period % steps_per_epoch or
      checkpoint_period and checkpoint_period % steps_per_epoch):
    raise ValueError(
        f'Checkpoint period ({checkpoint_period}) must evenly divide eval '
        f'period ({eval_period}), or vice-versa.')
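  # For example (values are illustrative only): with eval_period=1000 and
  # checkpoint_period=500, steps_per_epoch becomes 500 and both periods are
  # multiples of it, so the check above passes; with eval_period=300 and
  # checkpoint_period=500, steps_per_epoch becomes 300 and 500 % 300 != 0, so
  # the ValueError above is raised.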

  if use_hardware_rng or random_seed is None:
    logging.info(
        'Using fast RngBitGenerator PRNG for initialization and dropout.')

    if random_seed is None:
      random_seed = multihost_utils.broadcast_one_to_all(np.int32(time.time()))
      logging.info('Random seed not provided, using RNG seed %s', random_seed)
    else:
      logging.warning(
          'When using hardware RNG with a fixed seed, repeatability is only '
          'guaranteed for fixed hardware and partitioning schemes and for a '
          'fixed version of this code and its dependencies.')

    utils.set_hardware_rng_ops()
    rng = random.PRNGKey(random_seed)
  else:
    logging.info('Using seed for initialization and dropout RNG: %d',
                 random_seed)
    rng = random.PRNGKey(random_seed)

  init_rng, trainer_rng = random.split(rng, 2)

  # ---------------------------------------------------------------------------
  # Initialize datasets
  # ---------------------------------------------------------------------------

  if (train_dataset_cfg.seed and
      not (checkpoint_cfg.save and checkpoint_cfg.save.save_dataset)):
    logging.warning(
        'Providing a random seed for the train dataset with '
        '`checkpoint_train_ds=False` is dangerous since each '
        'preemption/restart will cause the dataset to deterministically '
        'replay from the beginning.')

  data_layout = partitioner.get_data_layout(train_dataset_cfg.batch_size)
  ds_shard_id = data_layout.shard_id
  num_ds_shards = data_layout.num_shards
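  # Each host loads only its own shard of the training data. As an
  # illustrative (not prescriptive) example: with 16 hosts sharing the data
  # pipeline, num_ds_shards would typically be 16 and ds_shard_id the index
  # of this host's shard.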

  def _verify_matching_vocabs(cfg: utils.DatasetConfig):
    ds_vocabs = utils.get_vocabulary(cfg)
    if (ds_vocabs[0] != model.input_vocabulary or
        ds_vocabs[1] != model.output_vocabulary):
      raise ValueError(f'Model and Task vocabularies do not match:\n'
                       f'  task={cfg.mixture_or_task_name}\n'
                       f'  ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n'
                       f'  model.input_vocabulary={model.input_vocabulary}\n'
                       f'  model.output_vocabulary={model.output_vocabulary}\n')

  _verify_matching_vocabs(train_dataset_cfg)

  train_ds = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards,
                            model.FEATURE_CONVERTER_CLS)
  if isinstance(train_ds, tf.data.Dataset):
    train_iter = clu.data.TfDatasetIterator(train_ds)
  elif isinstance(train_ds, clu.data.DatasetIterator):
    train_iter = train_ds
  else:
    raise ValueError(
        f'get_dataset_fn returned unsupported type {type(train_ds)}.')

  if train_eval_dataset_cfg:
    _verify_matching_vocabs(train_eval_dataset_cfg)
    train_eval_datasets = utils.get_training_eval_datasets(
        train_eval_dataset_cfg,
        ds_shard_id,
        num_ds_shards,
        eval_steps,
        model.FEATURE_CONVERTER_CLS,
        get_dataset_fn=train_eval_get_dataset_fn if train_eval_get_dataset_fn
        is not None else get_dataset_fn)  # type: Mapping[str, tf.data.Dataset]
    if not train_eval_datasets:
      logging.warning(
          'No train_eval datasets loaded from config '
          '`train_eval_dataset_cfg`: %s', train_eval_dataset_cfg)
  else:
    train_eval_datasets = {}

  # The manner in which parameters are initialized follows this order of
  # preference:
  #  1. From a T5X checkpoint in `model_dir`, if one exists.
  #  2. From a T5X or TF checkpoint specified by `cfg.path`, if set.
  #  3. From scratch using `init_fn`.

  # 1. From a T5X checkpoint in `model_dir`, if one exists.
  if checkpoint_cfg.restore is not None:
    state_transforms_for_restore = [
        functools.partial(fn, is_resuming=True)
        for fn in checkpoint_cfg.restore.state_transformation_fns
    ]
  else:
    state_transforms_for_restore = []
  restore_cfgs = [
      utils.RestoreCheckpointConfig(
          path=model_dir,
          mode='latest',
          dtype=checkpoint_cfg.save.dtype,
          checkpointer_cls=checkpoint_cfg.save.checkpointer_cls,
          # Restore dataset state if it is being saved.
          restore_dataset=(checkpoint_cfg.save and
                           checkpoint_cfg.save.save_dataset),
          state_transformation_fns=state_transforms_for_restore)
  ]
  # 2. From a checkpoint specified by `checkpoint_cfg.restore.path`, if set.
  if checkpoint_cfg.restore:
    if checkpoint_cfg.restore.mode == 'all':
      raise ValueError(
          "Restore checkpoint mode 'all' is not supported in training.")

    # TODO(dhgarrette): Split "restore" behavior into separate configurations
    #     for the initial restoration for a new run, vs resuming a stopped run.
    if isinstance(checkpoint_cfg.restore.path, str):
      restore_cfgs.append(checkpoint_cfg.restore)
    elif not checkpoint_cfg.restore.path:
      # `path` is an empty (non-`str`) sequence, so there is nothing to
      # restore.
      pass
    else:
      raise ValueError(
          'Restore checkpoint config may only have a single path in training.')

  # Need to use full batch size.
  input_shapes = {
      k: (data_layout.batch_size, *v.shape[1:])
      for k, v in train_ds.element_spec.items()
  }
  input_types = {
      k: v.dtype.as_numpy_dtype() for k, v in train_ds.element_spec.items()
  }
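  # Purely illustrative sketch (actual feature names depend on the model's
  # feature converter): for an encoder-decoder model with a global batch size
  # of 128 and task feature lengths {'inputs': 512, 'targets': 114},
  # `input_shapes` might look like
  #   {'encoder_input_tokens': (128, 512), 'decoder_input_tokens': (128, 114), ...}
  # and `input_types` like {'encoder_input_tokens': np.int32, ...}.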

  init_or_restore_tick = time.time()
  train_state_initializer = utils.TrainStateInitializer(
      optimizer_def=model.optimizer_def,
      init_fn=model.get_initial_variables,
      input_shapes=input_shapes,
      input_types=input_types,
      partitioner=partitioner)

  # May be None, empty
  valid_restore_cfg, restore_paths = utils.get_first_valid_restore_config_and_paths(
      restore_cfgs)
  if len(restore_paths) > 1:
    raise ValueError('Multiple restore paths not permitted in training.')

  checkpointable_train_iter = (
      train_iter.iterator
      if isinstance(train_iter, clu.data.TfDatasetIterator) else None)
  checkpoint_manager = utils.LegacyCheckpointManager(
      checkpoint_cfg.save,
      valid_restore_cfg,
      train_state_initializer.global_train_state_shape,
      partitioner,
      ds_iter=checkpointable_train_iter,
      model_dir=model_dir,
      use_gda=use_gda)

  train_state = checkpoint_manager.restore(
      restore_paths, valid_restore_cfg,
      utils.get_fallback_state(
          valid_restore_cfg,
          lambda rng: train_state_initializer.from_scratch(rng).state_dict(),
          init_rng))

  # 3. If no checkpoint to restore, init from scratch.
  train_state = train_state or train_state_initializer.from_scratch(init_rng)
  train_state_axes = train_state_initializer.train_state_axes
  init_or_restore_secs = time.time() - init_or_restore_tick
  logging.info('Initialize/restore complete (%.2f seconds).',
               init_or_restore_secs)

  # Log the variable shapes information and write to a file.
  log_file = os.path.join(model_dir, 'model-info.txt')
  utils.log_model_info(log_file,
                       train_state_initializer.global_train_state_shape,
                       partitioner)

  # Restore step from last checkpoint or set to 0 if training from scratch.
  host_step = int(utils.get_local_data(train_state.step))  # pytype: disable=attribute-error

  # ---------------------------------------------------------------------------
  # Trainer
  # ---------------------------------------------------------------------------

  trainer: trainer_lib.BaseTrainer = trainer_cls(
      model=model,
      train_state=train_state,
      partitioner=partitioner,
      train_state_axes=train_state_axes,
      eval_names=train_eval_datasets.keys(),
      summary_dir=model_dir,
      rng=trainer_rng)
  del train_state

  train_metrics = trainer.train_metrics_manager
  summarize_config_fn(model_dir, train_metrics.summary_writer, host_step)
  train_metrics.write_scalar('timing/init_or_restore_seconds',
                             init_or_restore_secs, host_step)

  # ----------------------------------------------------------------------------
  # SeqIO (inference-based) evaluation setup
  # ----------------------------------------------------------------------------
  # Init evaluator to set up cached datasets
  evaluator = None
  if infer_eval_dataset_cfg is not None:
    _verify_matching_vocabs(infer_eval_dataset_cfg)
    evaluator = inference_evaluator_cls(
        log_dir=os.path.join(model_dir, 'inference_eval'),
        mixture_or_task_name=infer_eval_dataset_cfg.mixture_or_task_name,
        feature_converter=model.FEATURE_CONVERTER_CLS(pack=False),
        eval_split=infer_eval_dataset_cfg.split,
        use_cached=infer_eval_dataset_cfg.use_cached,
        seed=infer_eval_dataset_cfg.seed,
        sequence_length=infer_eval_dataset_cfg.task_feature_lengths,
        use_memory_cache=infer_eval_dataset_cfg.use_memory_cache)
    if not evaluator.eval_tasks:
      # Skip evaluation.
      evaluator = None

  if evaluator is not None:
    predict_fn = utils.get_infer_fn(
        infer_step=model.predict_batch,
        batch_size=infer_eval_dataset_cfg.batch_size,
        train_state_axes=train_state_axes,
        partitioner=partitioner)

    predict_with_aux_fn = utils.get_infer_fn(
        infer_step=model.predict_batch_with_aux,
        batch_size=infer_eval_dataset_cfg.batch_size,
        train_state_axes=train_state_axes,
        partitioner=partitioner)

    score_fn = utils.get_infer_fn(
        infer_step=model.score_batch,
        batch_size=infer_eval_dataset_cfg.batch_size,
        train_state_axes=train_state_axes,
        partitioner=partitioner)

  if actions is None:
    actions = {}

  if set(actions.keys()).difference(_ACTION_KEYS):
    raise ValueError(f'actions keys must be one of {_ACTION_KEYS}, but got : '
                     f'{actions.keys()}')

  # Transform the string key into proper ActionMode enum.
  actions = {trainer_lib.ActionMode[k]: v for k, v in actions.items()}

  if concurrent_metrics and actions.get(trainer_lib.ActionMode.INFER_EVAL,
                                        None) is not None:
    logging.warning('Actions for INFER_EVAL will not be triggered when async '
                    'metrics computation is enabled')
  if concurrent_metrics and actions.get(trainer_lib.ActionMode.TRAIN,
                                        None) is not None:
    logging.warning('Actions for TRAIN will not be triggered when async '
                    'metrics computation is enabled')

  # ----------------------------------------------------------------------------
  # Setup Eval Utility Functions
  # ----------------------------------------------------------------------------
  def _run_training_eval(first_run: bool = False):
    if first_run:
      logging.info('Compiling training eval loop.')
      trainer.compile_eval({
          task: utils.get_zeros_batch_like_dataset(ds)
          for task, ds in train_eval_datasets.items()
      })
    logging.info('Computing training evaluation metrics.')
    eval_batch_iters = {
        task: ds.as_numpy_iterator()
        for task, ds in train_eval_datasets.items()
    }
    eval_summaries = trainer.eval(eval_batch_iters)
    trainer.stop_training = run_actions(trainer_lib.ActionMode.TRAIN_EVAL,
                                        actions, trainer.train_state,
                                        eval_summaries)

  def _run_inference_eval():
    """Run prediction based inference eval."""
    if evaluator is None:
      return
    logging.info('Running inference evaluation.')
    evaluate_tick = time.time()
    all_metrics, _, _ = evaluator.evaluate(
        compute_metrics=jax.process_index() == 0,
        step=host_step,
        predict_fn=functools.partial(
            predict_fn,
            train_state=trainer.train_state,
            rng=jax.random.PRNGKey(0)),
        score_fn=functools.partial(score_fn, train_state=trainer.train_state),
        predict_with_aux_fn=functools.partial(
            predict_with_aux_fn,
            train_state=trainer.train_state,
            rng=jax.random.PRNGKey(0)),
    )
    if not concurrent_metrics:
      # Ensure metrics are finished being computed.
      all_metrics_done = all_metrics.result() or {}
      trainer.stop_training = run_actions(trainer_lib.ActionMode.INFER_EVAL,
                                          actions, trainer.train_state,
                                          all_metrics_done)
    train_metrics.write_scalar('timing/evaluate_seconds',
                               time.time() - evaluate_tick, host_step)

  # Optionally run teacher-forcing training eval and SeqIO inference-based
  # eval before training. Useful for testing how much a model knows before
  # any finetuning.
  if run_eval_before_training:
    if train_eval_datasets:
      logging.info('Running training eval before training.')
      _run_training_eval(first_run=True)
    if evaluator is not None:
      logging.info('Running inference eval before training.')
      _run_inference_eval()

  # ----------------------------------------------------------------------------
  # Main training loop
  # ----------------------------------------------------------------------------
  logging.info('Starting training loop.')

  first_step = host_step

  if total_steps < first_step:
    raise ValueError(
        f'Unexpected total_steps ({total_steps}) < checkpoint step '
        f' ({first_step}).')

  logging.info('Starting main loop over steps %d-%d', first_step, total_steps)

  steps_per_epoch = min(steps_per_epoch, total_steps)
  first_epoch = first_step // steps_per_epoch
  num_epochs = first_epoch + math.ceil(
      (total_steps - first_step) / steps_per_epoch)
  logging.info('Training with artificial "epochs" of %d steps.',
               steps_per_epoch)
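  # Worked example (numbers are illustrative only): resuming at
  # first_step=1100 with steps_per_epoch=500 and total_steps=2000 gives
  # first_epoch=2 and num_epochs=2 + ceil(900 / 500)=4, so the loop below
  # covers epochs 2 and 3, training 500 and then 400 steps.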

  logging.info('Compiling train loop.')
  logging.flush()
  dummy_batch = {
      k: np.ones(v.shape, v.dtype) for k, v in train_iter.element_spec.items()
  }
  trainer.compile_train(dummy_batch)

  # Main Loop over "epochs".
  for epoch in range(first_epoch, num_epochs):
    final_epoch = epoch == num_epochs - 1
    logging.info('Epoch %d of %d', epoch, num_epochs)

    # If `stop_training` is requested, break out of the main loop immediately.
    if trainer.stop_training:
      break

    logging.info('BEGIN Train loop.')
    try:
      # Until the last epoch, `num_steps = steps_per_epoch`
      num_steps = min(total_steps - host_step, steps_per_epoch)
      epoch_end_step = host_step + num_steps
      logging.info('Training for %d steps.', num_steps)
      while host_step < epoch_end_step:
        if trainer.stop_training:
          logging.info('Saving a checkpoint before early stopping...')
          checkpoint_manager.save(trainer.train_state,
                                  checkpoint_cfg.save.state_transformation_fns)
          logging.info('Stopping training loop early since `stop_training` is '
                       'requested.')
          break

        inner_num_steps = min(epoch_end_step - host_step, stats_period)
        train_summary = trainer.train(
            train_iter, inner_num_steps, start_step=host_step)
        if not concurrent_metrics:
          # Note that we always pass the dictionary of `tasks` -> summary so
          # that the actions can be performed without special casing. The only
          # caveat is that train would need its own special `key` given no
          # `task` will be applied.
          trainer.stop_training = run_actions(
              trainer_lib.ActionMode.TRAIN, actions, trainer.train_state,
              {TRAIN_METRIC_KEY: train_summary.result()})

        host_step += inner_num_steps
      logging.info('END Train loop.')
    except trainer_lib.PreemptionError as e:
      logging.info('Saving emergency checkpoint.')
      checkpoint_manager.save(trainer.train_state,
                              checkpoint_cfg.save.state_transformation_fns)
      logging.info('Saving emergency checkpoint done.')
      raise e

    step_offset = host_step - first_step

    # Maybe save a checkpoint.
    if checkpoint_period and (final_epoch or
                              step_offset % checkpoint_period == 0):
      # Make sure last train step has completed before starting the clock.
      train_summary.result()
      logging.info('Saving checkpoint.')
      checkpoint_tick = time.time()
      checkpoint_manager.save(trainer.train_state,
                              checkpoint_cfg.save.state_transformation_fns)
      checkpoint_tock = time.time()
      train_metrics.write_scalar('timing/checkpoint_seconds',
                                 checkpoint_tock - checkpoint_tick, host_step)

    is_eval_epoch = eval_period and (final_epoch or
                                     step_offset % eval_period == 0)

    # Training Evaluation (i.e., with teacher forcing).
    if is_eval_epoch and train_eval_datasets:
      # Maybe less if final step < period.
      first_run = step_offset // eval_period <= 1
      _run_training_eval(first_run and not run_eval_before_training)

    # Inference Evaluation (i.e., with decoding or scoring).
    if evaluator is not None:
      _run_inference_eval()

  # Wait until computations are done before exiting
  logging.info('Finished.')
  trainer.close()
  if evaluator:
    evaluator.close()

  multihost_utils.sync_global_devices('complete')

  return host_step, trainer.train_state


if __name__ == '__main__':
  # pylint: disable=g-import-not-at-top
  from absl import app
  from absl import flags
  import gin
  from t5x import gin_utils
  # pylint: enable=g-import-not-at-top

  FLAGS = flags.FLAGS

  jax.config.parse_flags_with_absl()

  flags.DEFINE_multi_string(
      'gin_file',
      default=None,
      help='Path to gin configuration file. Multiple paths may be passed and '
      'will be imported in the given order, with later configurations '
      'overriding earlier ones.')

  flags.DEFINE_multi_string(
      'gin_bindings', default=[], help='Individual gin bindings.')

  flags.DEFINE_list(
      'gin_search_paths',
      default=['.'],
      help='Comma-separated list of gin config path prefixes to be prepended '
      'to suffixes given via `--gin_file`. Only the first prefix that '
      'produces a valid path for each suffix will be used.')

  flags.DEFINE_string(
      'tfds_data_dir', None,
      'If set, this directory will be used to store datasets prepared by '
      'TensorFlow Datasets that are not available in the public TFDS GCS '
      'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
      'all `Task`s.')

  flags.DEFINE_list(
      'seqio_additional_cache_dirs', [],
      'Directories to search for cached Tasks in addition to defaults.')
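
  # For illustration only: the flags above might be combined on the command
  # line as follows (all paths and bindings below are hypothetical):
  #   --gin_search_paths=/path/to/my/configs \
  #   --gin_file=finetune.gin \
  #   --gin_bindings="TRAIN_STEPS = 20000" \
  #   --tfds_data_dir=gs://my-bucket/tfds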

  def main(argv: Sequence[str]):
    """Wrapper for pdb post mortems."""
    _main(argv)

  def _main(argv: Sequence[str]):
    """True main function."""
    if len(argv) > 1:
      raise app.UsageError('Too many command-line arguments.')

    if FLAGS.tfds_data_dir:
      seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir)

    seqio.add_global_cache_dirs(FLAGS.seqio_additional_cache_dirs)

    # Create gin-configurable version of `train`.
    train_using_gin = gin.configurable(train)

    gin_utils.parse_gin_flags(
        # User-provided gin paths take precedence if relative paths conflict.
        FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
        FLAGS.gin_file,
        FLAGS.gin_bindings)
    train_using_gin()

  gin_utils.run(main)