r"""Script to pretrain or finetune in JAX using a SeqIO pipeline. |
|
|
|
""" |
|
import functools |
|
import itertools |
|
import math |
|
import os |
|
import time |
|
from typing import Callable, Iterator, Sequence, Mapping, Tuple, Type, Optional |
|
import subprocess |
|
|
|
|
|
|
|
|
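# Enable Flax profiling annotations so module calls show up in profiler traces.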
|
os.environ['FLAX_PROFILE'] = 'true' |
|
|
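# Opt out of Flax's lazy RNG handling.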
|
os.environ['FLAX_LAZY_RNG'] = 'no' |
|
from absl import logging |
|
from clu import metric_writers |
|
import jax |
|
from jax import random |
|
from jax.experimental import multihost_utils |
|
import jax.numpy as jnp |
|
import numpy as np |
|
import seqio |
|
from t5x import models |
|
from t5x import partitioning |
|
from t5x import train_state as train_state_lib |
|
from t5x import trainer as trainer_lib |
|
from t5x import utils |
|
from t5x import checkpoint_importer |
|
LazyArray = checkpoint_importer.LazyArray |
|
import tensorflow as tf |
|
|
|
|
|
|
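# Automatically search for gin files relative to the t5x package.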
|
_DEFAULT_GIN_SEARCH_PATHS = [ |
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__))) |
|
] |
|
PyTreeDef = type(jax.tree_structure(None)) |
|
P = partitioning.PartitionSpec |
|
|
|
TRAIN_METRIC_KEY = 'train' |
|
|
|
_ACTION_KEYS = frozenset(trainer_lib.ActionMode.__members__.keys()) |
|
|
|
|
|
def run_actions( |
|
mode: trainer_lib.ActionMode, actions: trainer_lib.ActionMapType, |
|
train_state: train_state_lib.TrainState, |
|
metrics_by_task: Mapping[str, trainer_lib.MetricValueMapType]) -> bool: |
|
"""Invokes all actions on the given mode on host 0, then broadcasts to all. |
|
|
|
Args: |
|
    mode: The mode in which to run the actions; e.g., if mode is `train`, only
      actions configured to run in `train` mode will be invoked.
|
    actions: A mapping of actions that run after train, eval or infer_eval, to
|
inspect the model and perform useful operations, e.g., early stopping. |
|
train_state: The current train_state of the trainer. |
|
metrics_by_task: A map of metrics keyed by task name. |
|
|
|
Returns: |
|
A bool indicating whether training should be halted. |
|
|
|
Raises: |
|
    RuntimeError: When `metrics_by_task` is unexpectedly empty on process 0.
|
""" |
|
stop_training = False |
|
if jax.process_index() == 0: |
|
if not metrics_by_task: |
|
raise RuntimeError('Metric is unexpectedly empty on process 0') |
|
for action in actions.get(mode, []): |
|
stop_training |= action.run(train_state, metrics_by_task=metrics_by_task) |
|
|
|
return bool(multihost_utils.broadcast_one_to_all(jnp.array(stop_training))) |
|
|
|
|
|
def train( |
|
*, |
|
model: models.BaseTransformerModel, |
|
train_dataset_cfg: utils.DatasetConfig, |
|
train_eval_dataset_cfg: Optional[utils.DatasetConfig], |
|
infer_eval_dataset_cfg: Optional[utils.DatasetConfig], |
|
checkpoint_cfg: utils.CheckpointConfig, |
|
partitioner: partitioning.BasePartitioner, |
|
trainer_cls: Type[trainer_lib.BaseTrainer], |
|
model_dir: str, |
|
total_steps: int, |
|
eval_steps: int, |
|
eval_period: int, |
|
stats_period: Optional[int] = None, |
|
random_seed: Optional[int], |
|
use_hardware_rng: bool = False, |
|
summarize_config_fn: Callable[[str, metric_writers.MetricWriter, int], |
|
None], |
|
inference_evaluator_cls: Type[seqio.Evaluator] = seqio.Evaluator, |
|
get_dataset_fn: utils.GetDatasetCallable = utils.get_dataset, |
|
concurrent_metrics: bool = True, |
|
actions: Optional[Mapping[str, Sequence[trainer_lib.BaseAction]]] = None, |
|
train_eval_get_dataset_fn: Optional[utils.GetDatasetCallable] = None, |
|
run_eval_before_training: bool = False, |
|
    hub_model_id: Optional[str] = None,
|
) -> Tuple[int, train_state_lib.TrainState]: |
|
"""Train function. |
|
|
|
Args: |
|
model: The model object to use for training. |
|
train_dataset_cfg: Specification for the dataset to train with. |
|
train_eval_dataset_cfg: Specification for the dataset to evaluate with using |
|
the train metrics and no inference (e.g., uses teacher forcing). If None, |
|
train eval is disabled. |
|
infer_eval_dataset_cfg: Specification for the dataset to evaluate with using |
|
the inference metrics (e.g., uses sampled decoding). If None, inference |
|
eval is disabled. |
|
checkpoint_cfg: Specification for saving and restoring model parameters and |
|
dataset state to/from checkpoints. |
|
partitioner: Partitioner for model parameters and data across devices. |
|
trainer_cls: An implementation of BaseTrainer. |
|
model_dir: Path of directory to store checkpoints and metric summaries. |
|
total_steps: The step number to stop training after. The number of actual |
|
steps trained in this run will be this number minus the starting step from |
|
the checkpoint. |
|
eval_steps: The number of batches to process for each train-eval loop. |
|
eval_period: The number of train steps between each evaluation (both |
|
train-eval and infer-eval). |
|
stats_period: The number of train steps between writing scalar stats. If |
|
None, defaults to eval_period. |
|
random_seed: A random seed to use for dropout and initialization. If None, a |
|
fast, non-deterministic hardware-based RNG is used. |
|
use_hardware_rng: Whether to force using the RngBitGenerator based hardware |
|
rng, which takes seeds and acts similarly to software PRNG in that it |
|
should be seed-deterministic. The new RngBitGenerator custom PRNG system |
|
should be reproducible for a given sharding, but the numbers will change |
|
for different shardings of the same model. |
|
summarize_config_fn: A function that takes in the model directory, a |
|
      SummaryWriter, and the step number, and writes a summary of the
      configuration.
|
inference_evaluator_cls: seqio.Evaluator class to use for inference |
|
evaluation, potentially with bound configuration args. |
|
    get_dataset_fn: The callable used to get the train and train-eval datasets
|
based on the DatasetConfig and shard information. |
|
concurrent_metrics: If True, allow metrics computation and logging to |
|
overlap with training. Will likely result in additional TPU memory usage. |
|
    actions: A mapping of actions that run after train, eval or infer_eval, to
|
inspect the model and perform useful operations, e.g., early stopping. The |
|
key must have a 1:1 mapping to ActionMode enum. For EVAL actions to |
|
actually work, this requires `concurrent_metrics` to be turned off, |
|
since chaining futures and mutating states concurrently might be |
|
error-prone. |
|
    train_eval_get_dataset_fn: Optional callable used to get the train-eval
|
datasets based on the DatasetConfig and shard information. If missing, it |
|
defaults to `get_dataset_fn`. |
|
    run_eval_before_training: If True, calculate training eval and inference
      eval metrics before training begins.
    hub_model_id: Optional Hugging Face Hub repository ID. If set, each saved
      checkpoint is converted to Flax and committed/pushed to that repository
      (the current working directory is assumed to be a clone of it).
|
|
|
Returns: |
|
The tuple of (last_step, last_train_state). |
|
""" |
|
logging.info('Process ID: %d', jax.process_index()) |
|
tf.io.gfile.makedirs(model_dir) |
|
|
|
|
|
|
|
|
|
|
|
eval_enabled = (train_eval_dataset_cfg or infer_eval_dataset_cfg) |
|
eval_period = eval_period if eval_enabled else 0 |
|
checkpoint_period = checkpoint_cfg.save.period if checkpoint_cfg.save else 0 |
|
if eval_period or checkpoint_period: |
|
steps_per_epoch = min(eval_period or np.inf, checkpoint_period or np.inf) |
|
else: |
|
steps_per_epoch = total_steps |
|
stats_period = stats_period or steps_per_epoch |
|
if (eval_period and eval_period % steps_per_epoch or |
|
checkpoint_period and checkpoint_period % steps_per_epoch): |
|
raise ValueError( |
|
f'Checkpoint period ({checkpoint_period}) must evenly divide eval ' |
|
f'period ({eval_period}), or vice-versa.') |
|
|
|
if use_hardware_rng or random_seed is None: |
|
logging.info( |
|
'Using fast RngBitGenerator PRNG for initialization and dropout.') |
|
|
|
if random_seed is None: |
|
random_seed = multihost_utils.broadcast_one_to_all(np.int32(time.time())) |
|
logging.info('Random seed not provided, using RNG seed %s', random_seed) |
|
else: |
|
logging.warning( |
|
'When using hardware RNG with a fixed seed, repeatability is only ' |
|
'guaranteed for fixed hardware and partitioning schemes and for a ' |
|
'fixed version of this code and its dependencies.') |
|
utils.set_hardware_rng_ops() |
|
rng = random.PRNGKey(random_seed) |
|
else: |
|
logging.info('Using seed for initialization and dropout RNG: %d', |
|
random_seed) |
|
rng = random.PRNGKey(random_seed) |
|
|
|
init_rng, trainer_rng = random.split(rng, 2) |
|
|
|
|
|
|
|
|
|
|
|
  if (train_dataset_cfg.seed and
      not (checkpoint_cfg.save and checkpoint_cfg.save.save_dataset)):
|
logging.warning( |
|
'Providing a random seed for the train dataset with ' |
|
'`checkpoint_train_ds=False` is dangerous since each ' |
|
'preemption/restart will cause the dataset to deterministically replay ' |
|
'from the beginning.') |
|
|
|
data_layout = partitioner.get_data_layout(train_dataset_cfg.batch_size) |
|
ds_shard_id = data_layout.shard_id |
|
num_ds_shards = data_layout.num_shards |
|
|
|
def _verify_matching_vocabs(cfg: utils.DatasetConfig): |
|
ds_vocabs = utils.get_vocabulary(cfg) |
|
if (ds_vocabs[0] != model.input_vocabulary or |
|
ds_vocabs[1] != model.output_vocabulary): |
|
raise ValueError(f'Model and Task vocabularies do not match:\n' |
|
f' task={cfg.mixture_or_task_name}\n' |
|
f' ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n' |
|
f' model.input_vocabulary={model.input_vocabulary}\n' |
|
f' model.output_vocabulary={model.output_vocabulary}\n') |
|
|
|
_verify_matching_vocabs(train_dataset_cfg) |
|
|
|
train_ds = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards, |
|
model.FEATURE_CONVERTER_CLS) |
|
|
|
if train_eval_dataset_cfg: |
|
_verify_matching_vocabs(train_eval_dataset_cfg) |
|
train_eval_datasets = utils.get_training_eval_datasets( |
|
train_eval_dataset_cfg, |
|
ds_shard_id, |
|
num_ds_shards, |
|
eval_steps, |
|
model.FEATURE_CONVERTER_CLS, |
|
get_dataset_fn=train_eval_get_dataset_fn if train_eval_get_dataset_fn |
|
is not None else get_dataset_fn) |
|
if not train_eval_datasets: |
|
logging.warning( |
|
'No train_eval datasets loaded from config `train_eval_dataset_cfg`: ' |
|
'%s', train_eval_dataset_cfg) |
|
else: |
|
train_eval_datasets = {} |
|
|
|
|
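  # Keep a handle on the raw tf.data iterator so its state can be saved in
  # checkpoints, and wrap it to convert each batch to NumPy for the trainer.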
|
checkpointable_train_iter: tf.data.Iterator = iter(train_ds) |
|
train_iter: Iterator[trainer_lib.BatchType] = map( |
|
lambda x: jax.tree_map(np.array, x), checkpointable_train_iter) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
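  # When resuming from `model_dir`, re-apply the configured restore state
  # transformations with `is_resuming=True`.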
|
if checkpoint_cfg.restore is not None: |
|
state_transforms_for_restore = [ |
|
functools.partial(fn, is_resuming=True) |
|
for fn in checkpoint_cfg.restore.state_transformation_fns |
|
] |
|
else: |
|
state_transforms_for_restore = [] |
|
restore_cfgs = [ |
|
utils.RestoreCheckpointConfig( |
|
path=model_dir, |
|
mode='latest', |
|
dtype=checkpoint_cfg.save.dtype, |
|
checkpointer_cls=checkpoint_cfg.save.checkpointer_cls, |
|
|
|
restore_dataset=(checkpoint_cfg.save and |
|
checkpoint_cfg.save.save_dataset), |
|
state_transformation_fns=state_transforms_for_restore) |
|
] |
|
|
|
if checkpoint_cfg.restore: |
|
if checkpoint_cfg.restore.mode == 'all': |
|
raise ValueError( |
|
"Restore checkpoint mode 'all' is not supported in training.") |
|
|
|
|
|
|
|
if isinstance(checkpoint_cfg.restore.path, str): |
|
restore_cfgs.append(checkpoint_cfg.restore) |
|
elif not checkpoint_cfg.restore.path: |
|
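      # Nothing to restore from; rely on the `model_dir` config above or
      # initialize from scratch.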
|
|
pass |
|
else: |
|
raise ValueError( |
|
'Restore checkpoint config may only have a single path in training.') |
|
|
|
|
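  # Derive per-host input shapes and dtypes from the training dataset spec.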
|
input_shapes = { |
|
k: (data_layout.batch_size, *v.shape[1:]) |
|
for k, v in train_ds.element_spec.items() |
|
} |
|
input_types = { |
|
k: v.dtype.as_numpy_dtype() for k, v in train_ds.element_spec.items() |
|
} |
|
init_or_restore_tick = time.time() |
|
train_state_initializer = utils.TrainStateInitializer( |
|
optimizer_def=model.optimizer_def, |
|
init_fn=model.get_initial_variables, |
|
input_shapes=input_shapes, |
|
input_types=input_types, |
|
partitioner=partitioner) |
|
|
|
train_state = train_state_initializer.from_checkpoint_or_scratch( |
|
restore_cfgs, init_rng=init_rng, ds_iter=checkpointable_train_iter) |
|
train_state_axes = train_state_initializer.train_state_axes |
|
init_or_restore_secs = time.time() - init_or_restore_tick |
|
logging.info('Initialize/restore complete (%.2f seconds).', |
|
init_or_restore_secs) |
|
|
|
|
|
log_file = os.path.join(model_dir, 'model-info.txt') |
|
utils.log_model_info(log_file, |
|
train_state_initializer.global_train_state_shape, |
|
partitioner) |
|
|
|
if checkpoint_period: |
|
checkpointer = checkpoint_cfg.save.checkpointer_cls( |
|
train_state=train_state_initializer.global_train_state_shape, |
|
partitioner=partitioner, |
|
checkpoints_dir=model_dir, |
|
dataset_iterator=(checkpointable_train_iter |
|
if checkpoint_cfg.save.save_dataset else None), |
|
save_dtype=checkpoint_cfg.save.dtype, |
|
keep=checkpoint_cfg.save.keep) |
|
|
|
|
|
|
|
host_step = int(train_state.step) |
|
|
|
|
|
|
|
|
|
|
|
trainer: trainer_lib.BaseTrainer = trainer_cls( |
|
model=model, |
|
train_state=train_state, |
|
partitioner=partitioner, |
|
train_state_axes=train_state_axes, |
|
eval_names=train_eval_datasets.keys(), |
|
summary_dir=model_dir, |
|
rng=trainer_rng) |
|
del train_state |
|
|
|
train_metrics = trainer.train_metrics_manager |
|
summarize_config_fn(model_dir, train_metrics.summary_writer, host_step) |
|
|
|
train_metrics.write_scalar('timing/init_or_restore_seconds', |
|
init_or_restore_secs, host_step) |
|
|
|
|
|
|
|
|
|
|
|
evaluator = None |
|
if infer_eval_dataset_cfg is not None: |
|
_verify_matching_vocabs(infer_eval_dataset_cfg) |
|
evaluator = inference_evaluator_cls( |
|
log_dir=os.path.join(model_dir, 'inference_eval'), |
|
mixture_or_task_name=infer_eval_dataset_cfg.mixture_or_task_name, |
|
feature_converter=model.FEATURE_CONVERTER_CLS(pack=False), |
|
eval_split=infer_eval_dataset_cfg.split, |
|
use_cached=infer_eval_dataset_cfg.use_cached, |
|
seed=infer_eval_dataset_cfg.seed, |
|
sequence_length=infer_eval_dataset_cfg.task_feature_lengths, |
|
use_memory_cache=infer_eval_dataset_cfg.use_memory_cache) |
|
if not evaluator.eval_tasks: |
|
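      # Skip inference evaluation entirely when there are no tasks to run.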
|
|
evaluator = None |
|
|
|
if evaluator is not None: |
|
predict_fn = utils.get_infer_fn( |
|
infer_step=model.predict_batch, |
|
batch_size=infer_eval_dataset_cfg.batch_size, |
|
train_state_axes=train_state_axes, |
|
partitioner=partitioner) |
|
|
|
score_fn = utils.get_infer_fn( |
|
infer_step=model.score_batch, |
|
batch_size=infer_eval_dataset_cfg.batch_size, |
|
train_state_axes=train_state_axes, |
|
partitioner=partitioner) |
|
|
|
if actions is None: |
|
actions = {} |
|
|
|
if set(actions.keys()).difference(_ACTION_KEYS): |
|
raise ValueError(f'actions keys must be one of {_ACTION_KEYS}, but got : ' |
|
f'{actions.keys()}') |
|
|
|
|
|
actions = {trainer_lib.ActionMode[k]: v for k, v in actions.items()} |
|
|
|
if concurrent_metrics and actions.get(trainer_lib.ActionMode.INFER_EVAL, |
|
None) is not None: |
|
logging.warning('Actions for INFER_EVAL will not be triggered when async ' |
|
'metrics computation is enabled') |
|
if concurrent_metrics and actions.get(trainer_lib.ActionMode.TRAIN, |
|
None) is not None: |
|
logging.warning('Actions for TRAIN will not be triggered when async ' |
|
'metrics computation is enabled') |
|
|
|
|
|
|
|
|
|
def _run_training_eval(first_run: bool = False): |
|
if first_run: |
|
logging.info('Compiling training eval loop.') |
|
trainer.compile_eval({ |
|
task: utils.get_zeros_batch_like_dataset(ds) |
|
for task, ds in train_eval_datasets.items() |
|
}) |
|
logging.info('Computing training evaluation metrics.') |
|
eval_batch_iters = { |
|
task: ds.as_numpy_iterator() |
|
for task, ds in train_eval_datasets.items() |
|
} |
|
eval_summaries = trainer.eval(eval_batch_iters) |
|
trainer.stop_training = run_actions(trainer_lib.ActionMode.TRAIN_EVAL, |
|
actions, trainer.train_state, |
|
eval_summaries) |
|
|
|
def _run_inference_eval(): |
|
"""Run prediction based inference eval.""" |
|
if evaluator is None: |
|
return |
|
logging.info('Running inference evaluation.') |
|
evaluate_tick = time.time() |
|
all_metrics, _, _ = evaluator.evaluate( |
|
compute_metrics=jax.process_index() == 0, |
|
step=host_step, |
|
predict_fn=functools.partial( |
|
predict_fn, |
|
train_state=trainer.train_state, |
|
rng=jax.random.PRNGKey(0)), |
|
score_fn=functools.partial(score_fn, train_state=trainer.train_state)) |
|
if not concurrent_metrics: |
|
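      # Block until the metrics have been computed so the actions below see
      # concrete values rather than futures.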
|
|
all_metrics_done = all_metrics.result() or {} |
|
trainer.stop_training = run_actions(trainer_lib.ActionMode.INFER_EVAL, |
|
actions, trainer.train_state, |
|
all_metrics_done) |
|
train_metrics.write_scalar('timing/evaluate_seconds', |
|
time.time() - evaluate_tick, host_step) |
|
|
|
|
|
|
|
|
|
if run_eval_before_training: |
|
if train_eval_datasets: |
|
logging.info('Running training eval before training.') |
|
_run_training_eval(first_run=True) |
|
if evaluator is not None: |
|
logging.info('Running inference eval before training.') |
|
_run_inference_eval() |
|
|
|
|
|
|
|
|
|
logging.info('Starting training loop.') |
|
|
|
first_step = host_step |
|
|
|
if total_steps < first_step: |
|
raise ValueError( |
|
f'Unexpected total_steps ({total_steps}) < checkpoint step ' |
|
        f'({first_step}).')
|
|
|
logging.info('Starting main loop over steps %d-%d', first_step, total_steps) |
|
|
|
steps_per_epoch = min(steps_per_epoch, total_steps) |
|
first_epoch = first_step // steps_per_epoch |
|
num_epochs = first_epoch + math.ceil( |
|
(total_steps - first_step) / steps_per_epoch) |
|
logging.info('Training with artificial "epochs" of %d steps.', |
|
steps_per_epoch) |
|
|
|
|
|
logging.info('Kickstarting train dataset prefetch.') |
|
logging.flush() |
|
|
|
ds_tick = time.time() |
|
|
|
first_batch = next(train_iter) |
|
|
|
train_iter = itertools.chain([first_batch], train_iter) |
|
train_metrics.write_scalar('timing/dataset_warmup_seconds', |
|
time.time() - ds_tick, host_step) |
|
logging.info('Compiling train loop.') |
|
logging.flush() |
|
trainer.compile_train(first_batch) |
|
|
|
|
|
for epoch in range(first_epoch, num_epochs): |
|
final_epoch = epoch == num_epochs - 1 |
|
logging.info('Epoch %d of %d', epoch, num_epochs) |
|
|
|
|
|
if trainer.stop_training: |
|
break |
|
|
|
logging.info('BEGIN Train loop.') |
|
try: |
|
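      # Train for the remainder of the epoch, but never past `total_steps`.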
|
|
num_steps = min(total_steps - host_step, steps_per_epoch) |
|
epoch_end_step = host_step + num_steps |
|
logging.info('Training for %d steps.', num_steps) |
|
while host_step < epoch_end_step: |
|
if trainer.stop_training: |
|
logging.info('Saving a checkpoint before early stopping...') |
|
checkpointer.save(trainer.train_state, |
|
checkpoint_cfg.save.state_transformation_fns) |
|
|
|
if hub_model_id: |
|
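            # Convert the just-saved T5X checkpoint to Flax and push it to the
            # Hugging Face Hub. The working directory is assumed to be a git
            # clone of the `hub_model_id` repository; `git push` runs in the
            # background so it does not block shutdown.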
|
|
            checkpoint_step = trainer.train_state.step
            if isinstance(checkpoint_step, LazyArray):
              checkpoint_step = checkpoint_step.get()
            checkpoint_step = int(checkpoint_step)
            config_path = os.path.join(model_dir, 'config.json')
            subprocess.run([
                'python3', 'convert_t5x_checkpoint_to_flax.py',
                f'--t5x_checkpoint_path=checkpoint_{checkpoint_step}/',
                f'--config_name={config_path}',
                '--flax_dump_folder_path=./',
            ])
            subprocess.run('git lfs prune --verify-remote', shell=True)
            subprocess.run('git add .', shell=True)
            subprocess.run(
                'git commit -m "Saving weights and logs of step '
                f'{checkpoint_step}"',
                shell=True)
            subprocess.Popen('git push', shell=True)
|
|
|
logging.info('Stopping training loop early since `stop_training` is ' |
|
'requested.') |
|
break |
|
|
|
inner_num_steps = min(epoch_end_step - host_step, stats_period) |
|
train_summary = trainer.train( |
|
train_iter, inner_num_steps, start_step=host_step) |
|
if not concurrent_metrics: |
|
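          # Compute the train metrics synchronously and run any TRAIN actions
          # on the result.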
|
|
|
|
|
|
|
|
trainer.stop_training = run_actions( |
|
trainer_lib.ActionMode.TRAIN, actions, trainer.train_state, |
|
{TRAIN_METRIC_KEY: train_summary.result()}) |
|
|
|
host_step += inner_num_steps |
|
logging.info('END Train loop.') |
|
except trainer_lib.PreemptionError as e: |
|
logging.info('Saving emergency checkpoint.') |
|
checkpointer.save(trainer.train_state, |
|
checkpoint_cfg.save.state_transformation_fns) |
|
logging.info('Saving emergency checkpoint done.') |
|
raise e |
|
|
|
step_offset = host_step - first_step |
|
|
|
is_eval_epoch = eval_period and (final_epoch or |
|
step_offset % eval_period == 0) |
|
|
|
|
|
if is_eval_epoch and train_eval_datasets: |
|
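      # Recompile the eval loop only on the first eval of this run (unless it
      # already ran before training).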
|
|
first_run = step_offset // eval_period <= 1 |
|
_run_training_eval(first_run and not run_eval_before_training) |
|
|
|
|
|
if checkpoint_period and (final_epoch or |
|
step_offset % checkpoint_period == 0): |
|
|
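      # Wait for the training summary to resolve before writing a checkpoint.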
|
train_summary.result() |
|
logging.info('Saving checkpoint.') |
|
checkpoint_tick = time.time() |
|
checkpointer.save(trainer.train_state, |
|
checkpoint_cfg.save.state_transformation_fns) |
|
checkpoint_tock = time.time() |
|
train_metrics.write_scalar('timing/checkpoint_seconds', |
|
checkpoint_tock - checkpoint_tick, host_step) |
|
|
|
if hub_model_id: |
|
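        # Convert the new T5X checkpoint to Flax and push it to the Hugging
        # Face Hub, exactly as in the early-stopping branch above.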
|
|
        checkpoint_step = trainer.train_state.step
        if isinstance(checkpoint_step, LazyArray):
          checkpoint_step = checkpoint_step.get()
        checkpoint_step = int(checkpoint_step)
        config_path = os.path.join(model_dir, 'config.json')
        subprocess.run([
            'python3', 'convert_t5x_checkpoint_to_flax.py',
            f'--t5x_checkpoint_path=checkpoint_{checkpoint_step}/',
            f'--config_name={config_path}',
            '--flax_dump_folder_path=./',
        ])
        subprocess.run('git lfs prune --verify-remote', shell=True)
        subprocess.run('git add .', shell=True)
        subprocess.run(
            'git commit -m "Saving weights and logs of step '
            f'{checkpoint_step}"',
            shell=True)
        subprocess.Popen('git push', shell=True)
|
|
|
|
|
if evaluator is not None: |
|
_run_inference_eval() |
|
|
|
|
|
logging.info('Finished.') |
|
trainer.close() |
|
if evaluator: |
|
evaluator.close() |
|
multihost_utils.sync_global_devices('complete') |
|
|
|
return host_step, trainer.train_state |
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
from absl import app |
|
from absl import flags |
|
import gin |
|
from t5x import gin_utils |
|
|
|
|
|
FLAGS = flags.FLAGS |
|
|
|
jax.config.parse_flags_with_absl() |
|
|
|
flags.DEFINE_multi_string( |
|
'gin_file', |
|
default=None, |
|
help='Path to gin configuration file. Multiple paths may be passed and ' |
|
'will be imported in the given order, with later configurations ' |
|
'overriding earlier ones.') |
|
|
|
flags.DEFINE_multi_string( |
|
'gin_bindings', default=[], help='Individual gin bindings.') |
|
|
|
flags.DEFINE_list( |
|
'gin_search_paths', |
|
default=['.'], |
|
help='Comma-separated list of gin config path prefixes to be prepended ' |
|
      'to suffixes given via `--gin_file`. Only the first prefix that '
      'produces a valid path for each suffix will be used.')
|
|
|
flags.DEFINE_string( |
|
'tfds_data_dir', None, |
|
'If set, this directory will be used to store datasets prepared by ' |
|
'TensorFlow Datasets that are not available in the public TFDS GCS ' |
|
'bucket. Note that this flag overrides the `tfds_data_dir` attribute of ' |
|
'all `Task`s.') |
|
|
|
flags.DEFINE_list( |
|
'seqio_additional_cache_dirs', [], |
|
'Directories to search for cached Tasks in addition to defaults.') |
|
|
|
|
|
|
|
def main(argv: Sequence[str]): |
|
"""Wrapper for pdb post mortems.""" |
|
_main(argv) |
|
|
|
def _main(argv: Sequence[str]): |
|
"""True main function.""" |
|
if len(argv) > 1: |
|
raise app.UsageError('Too many command-line arguments.') |
|
|
|
if FLAGS.tfds_data_dir: |
|
seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir) |
|
|
|
seqio.add_global_cache_dirs(FLAGS.seqio_additional_cache_dirs) |
|
|
|
|
|
train_using_gin = gin.configurable(train) |
|
|
|
gin_utils.parse_gin_flags( |
|
|
|
FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS, |
|
FLAGS.gin_file, |
|
FLAGS.gin_bindings) |
|
train_using_gin() |
|
|
|
gin_utils.run(main) |