# Provenance (from the hosting page): uploaded by juancopi81,
# commit b100e1c, "Add t5x and mt3 models".
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Script to pretrain or finetune in JAX using a SeqIO pipeline.
"""
import functools
import math
import os
import time
from typing import Callable, Sequence, Mapping, Tuple, Type, Optional
# Set Linen to add profiling information when constructing Modules.
# Must be set before flax imports.
# pylint:disable=g-import-not-at-top
os.environ['FLAX_PROFILE'] = 'true'
# TODO(adarob): Re-enable once users are notified and tests are updated.
os.environ['FLAX_LAZY_RNG'] = 'no'
from absl import logging
from clu import metric_writers
import clu.data
import jax
from jax import random
from jax.experimental import multihost_utils
import jax.numpy as jnp
import numpy as np
import seqio
from t5x import models
from t5x import partitioning
from t5x import train_state as train_state_lib
from t5x import trainer as trainer_lib
from t5x import utils
import tensorflow as tf
# Automatically search for gin files relative to the T5X package.
_DEFAULT_GIN_SEARCH_PATHS = [
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
]
# Use the `jax.tree_util` spelling: the top-level `jax.tree_structure` alias is
# deprecated, while `jax.tree_util.tree_structure` is the supported API.
PyTreeDef = type(jax.tree_util.tree_structure(None))
P = partitioning.PartitionSpec
# Special key used to distinguish train metrics.
TRAIN_METRIC_KEY = 'train'
# String keys that are acceptable from config.
_ACTION_KEYS = frozenset(trainer_lib.ActionMode.__members__.keys())
def run_actions(
    mode: trainer_lib.ActionMode, actions: trainer_lib.ActionMapType,
    train_state: train_state_lib.TrainState,
    metrics_by_task: Mapping[str, trainer_lib.MetricValueMapType]) -> bool:
  """Runs the actions registered for `mode` on process 0 and shares the result.

  Only process 0 executes the actions; every process then takes part in a
  broadcast so that all of them return the same decision.

  Args:
    mode: The mode whose actions should run. e.g., if mode is `train`, only
      actions configured for `train` mode are invoked.
    actions: A mapping from mode to the actions that run after train, eval or
      infer_eval, to inspect the model and perform useful operations, e.g.,
      early stopping.
    train_state: The current train_state of the trainer.
    metrics_by_task: A map of metrics keyed by task name.

  Returns:
    A bool indicating whether training should be halted.

  Raises:
    RuntimeError: When `metrics_by_task` is empty (or None) on process 0.
  """
  halt_requested = False
  is_lead_process = jax.process_index() == 0
  if is_lead_process:
    if not metrics_by_task:
      raise RuntimeError('Metric is unexpectedly empty on process 0')
    mode_actions = actions.get(mode, [])
    # Every action is run (no short-circuit); their verdicts are OR-ed together.
    for current_action in mode_actions:
      verdict = current_action.run(
          train_state, metrics_by_task=metrics_by_task)
      halt_requested |= verdict
  # Broadcast result from host 0 to others.
  shared_verdict = multihost_utils.broadcast_one_to_all(
      jnp.array(halt_requested))
  return bool(shared_verdict)
def train(
    *,
    model: models.BaseTransformerModel,
    train_dataset_cfg: utils.DatasetConfig,
    train_eval_dataset_cfg: Optional[utils.DatasetConfig],
    infer_eval_dataset_cfg: Optional[utils.DatasetConfig],
    checkpoint_cfg: utils.CheckpointConfig,
    partitioner: partitioning.BasePartitioner,
    trainer_cls: Type[trainer_lib.BaseTrainer],
    model_dir: str,
    total_steps: int,
    eval_steps: int,
    eval_period: int,
    stats_period: Optional[int] = None,
    random_seed: Optional[int],
    use_hardware_rng: bool = False,
    summarize_config_fn: Callable[[str, metric_writers.MetricWriter, int],
                                  None],
    inference_evaluator_cls: Type[seqio.Evaluator] = seqio.Evaluator,
    get_dataset_fn: utils.GetDatasetCallable = utils.get_dataset,
    concurrent_metrics: bool = True,
    actions: Optional[Mapping[str, Sequence[trainer_lib.BaseAction]]] = None,
    train_eval_get_dataset_fn: Optional[utils.GetDatasetCallable] = None,
    run_eval_before_training: bool = False,
    use_gda: bool = False) -> Tuple[int, train_state_lib.TrainState]:
  """Train function.

  Sets up the RNG, datasets, checkpoint restoration, trainer and (optionally)
  the inference evaluator, then runs the main loop in artificial "epochs" whose
  length is the smaller of the eval period and the checkpoint period.

  Args:
    model: The model object to use for training.
    train_dataset_cfg: Specification for the dataset to train with.
    train_eval_dataset_cfg: Specification for the dataset to evaluate with using
      the train metrics and no inference (e.g., uses teacher forcing). If None,
      train eval is disabled.
    infer_eval_dataset_cfg: Specification for the dataset to evaluate with using
      the inference metrics (e.g., uses sampled decoding). If None, inference
      eval is disabled.
    checkpoint_cfg: Specification for saving and restoring model parameters and
      dataset state to/from checkpoints.
    partitioner: Partitioner for model parameters and data across devices.
    trainer_cls: An implementation of BaseTrainer.
    model_dir: Path of directory to store checkpoints and metric summaries.
    total_steps: The step number to stop training after. The number of actual
      steps trained in this run will be this number minus the starting step from
      the checkpoint.
    eval_steps: The number of batches to process for each train-eval loop.
    eval_period: The number of train steps between each evaluation (both
      train-eval and infer-eval).
    stats_period: The number of train steps between writing scalar stats. If
      None, defaults to the "epoch" length (the smaller of the eval and
      checkpoint periods, or `total_steps` if neither is set).
    random_seed: A random seed to use for dropout and initialization. If None, a
      fast, non-deterministic hardware-based RNG is used.
    use_hardware_rng: Whether to force using the RngBitGenerator based hardware
      rng, which takes seeds and acts similarly to software PRNG in that it
      should be seed-deterministic. The new RngBitGenerator custom PRNG system
      should be reproducible for a given sharding, but the numbers will change
      for different shardings of the same model.
    summarize_config_fn: A function that takes in the model directory, a
      SummaryWriter, and the step number, and writes a summary of the
      configuration (presumably the gin config; confirm against the function
      bound by the caller).
    inference_evaluator_cls: seqio.Evaluator class to use for inference
      evaluation, potentially with bound configuration args.
    get_dataset_fn: The callable use to get the train and train-eval datasets
      based on the DatasetConfig and shard information.
    concurrent_metrics: If True, allow metrics computation and logging to
      overlap with training. Will likely result in additional TPU memory usage.
    actions: A mapping of actions that runs after train, eval or infer_eval, to
      inspect the model and perform useful operations, e.g., early stopping. The
      key must have a 1:1 mapping to ActionMode enum. For EVAL actions to
      actually work, this requires `concurrent_metrics` to be turned off, since
      chaining futures and mutating states concurrently might be error-prone.
    train_eval_get_dataset_fn: Optional callable use to get the train-eval
      datasets based on the DatasetConfig and shard information. If missing, it
      defaults to `get_dataset_fn`.
    run_eval_before_training: If True, calculate training eval and inference
      eval metrics before training begins.
    use_gda: if True, uses GlobalDeviceArray. Experimental feature.

  Returns:
    The tuple of (last_step, last_train_state).

  Raises:
    ValueError: If the eval and checkpoint periods do not divide one another,
      the model and task vocabularies differ, `get_dataset_fn` returns an
      unsupported type, the restore config uses mode 'all' or more than one
      path, `actions` contains an unknown key, or `total_steps` is smaller than
      the restored checkpoint step.
    trainer_lib.PreemptionError: Re-raised after an emergency checkpoint has
      been saved.
  """
  logging.info('Process ID: %d', jax.process_index())
  tf.io.gfile.makedirs(model_dir)

  jax.config.update('jax_parallel_functions_output_gda', use_gda)

  # Each "epoch" of the training loop should be the min of the eval period,
  # checkpoint period or the full training.
  # We compute here to ensure that the eval period and checkpoint period are
  # divisible by this number, otherwise we fail.
  eval_enabled = (train_eval_dataset_cfg or infer_eval_dataset_cfg)
  eval_period = eval_period if eval_enabled else 0
  checkpoint_period = checkpoint_cfg.save.period if checkpoint_cfg.save else 0
  if eval_period or checkpoint_period:
    steps_per_epoch = min(eval_period or np.inf, checkpoint_period or np.inf)
  else:
    steps_per_epoch = total_steps
  stats_period = stats_period or steps_per_epoch
  if (eval_period and eval_period % steps_per_epoch or
      checkpoint_period and checkpoint_period % steps_per_epoch):
    raise ValueError(
        f'Checkpoint period ({checkpoint_period}) must evenly divide eval '
        f'period ({eval_period}), or vice-versa.')

  if use_hardware_rng or random_seed is None:
    logging.info(
        'Using fast RngBitGenerator PRNG for initialization and dropout.')

    if random_seed is None:
      # Derive a seed from the clock on one host and share it so that all hosts
      # agree on the same value.
      random_seed = multihost_utils.broadcast_one_to_all(np.int32(time.time()))
      logging.info('Random seed not provided, using RNG seed %s', random_seed)
    else:
      logging.warning(
          'When using hardware RNG with a fixed seed, repeatability is only '
          'guaranteed for fixed hardware and partitioning schemes and for a '
          'fixed version of this code and its dependencies.')
    utils.set_hardware_rng_ops()
    rng = random.PRNGKey(random_seed)
  else:
    logging.info('Using seed for initialization and dropout RNG: %d',
                 random_seed)
    rng = random.PRNGKey(random_seed)

  init_rng, trainer_rng = random.split(rng, 2)

  # ---------------------------------------------------------------------------
  # Initialize datasets
  # ---------------------------------------------------------------------------

  # Warn when the dataset is seeded but its state is NOT being checkpointed.
  # This must be `and`: with `or`, a missing save config raised AttributeError
  # and a present one made the condition always false, so the warning could
  # never fire.
  if (train_dataset_cfg.seed and
      not (checkpoint_cfg.save and checkpoint_cfg.save.save_dataset)):
    logging.warning(
        'Providing a random seed for the train dataset with '
        '`checkpoint_train_ds=False` is dangerous since each '
        'preemption/restart will cause the dataset to deterministically replay '
        'from the beginning.')

  data_layout = partitioner.get_data_layout(train_dataset_cfg.batch_size)
  ds_shard_id = data_layout.shard_id
  num_ds_shards = data_layout.num_shards

  def _verify_matching_vocabs(cfg: utils.DatasetConfig):
    """Raises ValueError if the vocabularies of `cfg` differ from the model's."""
    ds_vocabs = utils.get_vocabulary(cfg)
    if (ds_vocabs[0] != model.input_vocabulary or
        ds_vocabs[1] != model.output_vocabulary):
      raise ValueError(f'Model and Task vocabularies do not match:\n'
                       f' task={cfg.mixture_or_task_name}\n'
                       f' ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n'
                       f' model.input_vocabulary={model.input_vocabulary}\n'
                       f' model.output_vocabulary={model.output_vocabulary}\n')

  _verify_matching_vocabs(train_dataset_cfg)

  train_ds = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards,
                            model.FEATURE_CONVERTER_CLS)
  if isinstance(train_ds, tf.data.Dataset):
    train_iter = clu.data.TfDatasetIterator(train_ds)
  elif isinstance(train_ds, clu.data.DatasetIterator):
    train_iter = train_ds
  else:
    raise ValueError(
        f'get_dataset_fn returned unsupported type {type(train_ds)}.')

  if train_eval_dataset_cfg:
    _verify_matching_vocabs(train_eval_dataset_cfg)
    train_eval_datasets = utils.get_training_eval_datasets(
        train_eval_dataset_cfg,
        ds_shard_id,
        num_ds_shards,
        eval_steps,
        model.FEATURE_CONVERTER_CLS,
        get_dataset_fn=train_eval_get_dataset_fn if train_eval_get_dataset_fn
        is not None else get_dataset_fn)  # type: Mapping[str, tf.data.Dataset]
    if not train_eval_datasets:
      logging.warning(
          'No train_eval datasets loaded from config `train_eval_dataset_cfg`: '
          '%s', train_eval_dataset_cfg)
  else:
    train_eval_datasets = {}

  # The manner in which parameters are initialized follows this order of
  # preference:
  #  1. From a T5X checkpoint in `model_dir`, if one exists.
  #  2. From a T5X or TF checkpoint specified by `cfg.path`, if set.
  #  3. From scratch using `init_fn`.

  # 1. From a T5X checkpoint in `model_dir`, if one exists.
  if checkpoint_cfg.restore is not None:
    state_transforms_for_restore = [
        functools.partial(fn, is_resuming=True)
        for fn in checkpoint_cfg.restore.state_transformation_fns
    ]
  else:
    state_transforms_for_restore = []
  # NOTE(review): `checkpoint_cfg.save` is dereferenced unconditionally below
  # (`dtype`, `checkpointer_cls`) even though other code in this function
  # tolerates it being None; a None save config fails here.
  restore_cfgs = [
      utils.RestoreCheckpointConfig(
          path=model_dir,
          mode='latest',
          dtype=checkpoint_cfg.save.dtype,
          checkpointer_cls=checkpoint_cfg.save.checkpointer_cls,
          # Restore dataset state if it is being saved.
          restore_dataset=(checkpoint_cfg.save and
                           checkpoint_cfg.save.save_dataset),
          state_transformation_fns=state_transforms_for_restore)
  ]
  # 2. From a checkpoint specified by `checkpoint_cfg.restore.path`, if set.
  if checkpoint_cfg.restore:
    if checkpoint_cfg.restore.mode == 'all':
      raise ValueError(
          "Restore checkpoint mode 'all' is not supported in training.")

    # TODO(dhgarrette): Split "restore" behavior into separate configurations
    # for the initial restoration for a new run, vs resuming a stopped run.
    if isinstance(checkpoint_cfg.restore.path, str):
      restore_cfgs.append(checkpoint_cfg.restore)
    elif not checkpoint_cfg.restore.path:
      # `path` is an empty (non-`str`) sequence, so there is nothing to restore.
      pass
    else:
      raise ValueError(
          'Restore checkpoint config may only have a single path in training.')

  # Need to use full batch size.
  input_shapes = {
      k: (data_layout.batch_size, *v.shape[1:])
      for k, v in train_ds.element_spec.items()
  }
  input_types = {
      k: v.dtype.as_numpy_dtype() for k, v in train_ds.element_spec.items()
  }
  init_or_restore_tick = time.time()
  train_state_initializer = utils.TrainStateInitializer(
      optimizer_def=model.optimizer_def,
      init_fn=model.get_initial_variables,
      input_shapes=input_shapes,
      input_types=input_types,
      partitioner=partitioner)
  # May be None, empty
  valid_restore_cfg, restore_paths = utils.get_first_valid_restore_config_and_paths(
      restore_cfgs)
  if len(restore_paths) > 1:
    raise ValueError('Multiple restore paths not permitted in training.')
  # Only the tf.data-backed iterator wrapper exposes an `.iterator` that is
  # handed to the checkpoint manager; other iterators are not checkpointed.
  checkpointable_train_iter = (
      train_iter.iterator
      if isinstance(train_iter, clu.data.TfDatasetIterator) else None)
  checkpoint_manager = utils.LegacyCheckpointManager(
      checkpoint_cfg.save,
      valid_restore_cfg,
      train_state_initializer.global_train_state_shape,
      partitioner,
      ds_iter=checkpointable_train_iter,
      model_dir=model_dir,
      use_gda=use_gda)

  train_state = checkpoint_manager.restore(
      restore_paths, valid_restore_cfg,
      utils.get_fallback_state(
          valid_restore_cfg,
          lambda rng: train_state_initializer.from_scratch(rng).state_dict(),
          init_rng))

  # 3. If no checkpoint to restore, init from scratch.
  train_state = train_state or train_state_initializer.from_scratch(init_rng)
  train_state_axes = train_state_initializer.train_state_axes
  init_or_restore_secs = time.time() - init_or_restore_tick
  logging.info('Initialize/restore complete (%.2f seconds).',
               init_or_restore_secs)

  # Log the variable shapes information and write to a file.
  log_file = os.path.join(model_dir, 'model-info.txt')
  utils.log_model_info(log_file,
                       train_state_initializer.global_train_state_shape,
                       partitioner)

  # Restore step from last checkpoint or set to 0 if training from scratch.
  host_step = int(utils.get_local_data(train_state.step))  # pytype: disable=attribute-error

  # ---------------------------------------------------------------------------
  # Trainer
  # ---------------------------------------------------------------------------

  trainer: trainer_lib.BaseTrainer = trainer_cls(
      model=model,
      train_state=train_state,
      partitioner=partitioner,
      train_state_axes=train_state_axes,
      eval_names=train_eval_datasets.keys(),
      summary_dir=model_dir,
      rng=trainer_rng)
  # From here on the state is read through `trainer.train_state` only.
  del train_state

  train_metrics = trainer.train_metrics_manager
  summarize_config_fn(model_dir, train_metrics.summary_writer, host_step)

  train_metrics.write_scalar('timing/init_or_restore_seconds',
                             init_or_restore_secs, host_step)

  # ----------------------------------------------------------------------------
  # SeqIO (inference-based) evaluation setup
  # ----------------------------------------------------------------------------
  # Init evaluator to set up cached datasets
  evaluator = None
  if infer_eval_dataset_cfg is not None:
    _verify_matching_vocabs(infer_eval_dataset_cfg)
    evaluator = inference_evaluator_cls(
        log_dir=os.path.join(model_dir, 'inference_eval'),
        mixture_or_task_name=infer_eval_dataset_cfg.mixture_or_task_name,
        feature_converter=model.FEATURE_CONVERTER_CLS(pack=False),
        eval_split=infer_eval_dataset_cfg.split,
        use_cached=infer_eval_dataset_cfg.use_cached,
        seed=infer_eval_dataset_cfg.seed,
        sequence_length=infer_eval_dataset_cfg.task_feature_lengths,
        use_memory_cache=infer_eval_dataset_cfg.use_memory_cache)
    if not evaluator.eval_tasks:
      # Skip evaluation.
      evaluator = None

  if evaluator is not None:
    predict_fn = utils.get_infer_fn(
        infer_step=model.predict_batch,
        batch_size=infer_eval_dataset_cfg.batch_size,
        train_state_axes=train_state_axes,
        partitioner=partitioner)

    predict_with_aux_fn = utils.get_infer_fn(
        infer_step=model.predict_batch_with_aux,
        batch_size=infer_eval_dataset_cfg.batch_size,
        train_state_axes=train_state_axes,
        partitioner=partitioner)

    score_fn = utils.get_infer_fn(
        infer_step=model.score_batch,
        batch_size=infer_eval_dataset_cfg.batch_size,
        train_state_axes=train_state_axes,
        partitioner=partitioner)

  if actions is None:
    actions = {}

  if set(actions.keys()).difference(_ACTION_KEYS):
    raise ValueError(f'actions keys must be one of {_ACTION_KEYS}, but got : '
                     f'{actions.keys()}')

  # Transform the string key into proper ActionMode enum.
  actions = {trainer_lib.ActionMode[k]: v for k, v in actions.items()}

  if concurrent_metrics and actions.get(trainer_lib.ActionMode.INFER_EVAL,
                                        None) is not None:
    logging.warning('Actions for INFER_EVAL will not be triggered when async '
                    'metrics computation is enabled')
  if concurrent_metrics and actions.get(trainer_lib.ActionMode.TRAIN,
                                        None) is not None:
    logging.warning('Actions for TRAIN will not be triggered when async '
                    'metrics computation is enabled')

  # ----------------------------------------------------------------------------
  # Setup Eval Utility Functions
  # ----------------------------------------------------------------------------
  def _run_training_eval(first_run: bool = False):
    """Runs train-eval, compiling the eval loop first when `first_run`."""
    if first_run:
      logging.info('Compiling training eval loop.')
      trainer.compile_eval({
          task: utils.get_zeros_batch_like_dataset(ds)
          for task, ds in train_eval_datasets.items()
      })
    logging.info('Computing training evaluation metrics.')
    eval_batch_iters = {
        task: ds.as_numpy_iterator()
        for task, ds in train_eval_datasets.items()
    }
    eval_summaries = trainer.eval(eval_batch_iters)
    trainer.stop_training = run_actions(trainer_lib.ActionMode.TRAIN_EVAL,
                                        actions, trainer.train_state,
                                        eval_summaries)

  def _run_inference_eval():
    """Run prediction based inference eval."""
    if evaluator is None:
      return
    logging.info('Running inference evaluation.')
    evaluate_tick = time.time()
    all_metrics, _, _ = evaluator.evaluate(
        compute_metrics=jax.process_index() == 0,
        step=host_step,
        predict_fn=functools.partial(
            predict_fn,
            train_state=trainer.train_state,
            rng=jax.random.PRNGKey(0)),
        score_fn=functools.partial(score_fn, train_state=trainer.train_state),
        predict_with_aux_fn=functools.partial(
            predict_with_aux_fn,
            train_state=trainer.train_state,
            rng=jax.random.PRNGKey(0)),
    )
    if not concurrent_metrics:
      # Ensure metrics are finished being computed.
      all_metrics_done = all_metrics.result() or {}
      trainer.stop_training = run_actions(trainer_lib.ActionMode.INFER_EVAL,
                                          actions, trainer.train_state,
                                          all_metrics_done)
    train_metrics.write_scalar('timing/evaluate_seconds',
                               time.time() - evaluate_tick, host_step)

  # Optionally run teacher-forcing training eval and SeqIO inference-base eval
  # before training. Useful for testing how much a model knows before any
  # finetuning.
  if run_eval_before_training:
    if train_eval_datasets:
      logging.info('Running training eval before training.')
      _run_training_eval(first_run=True)
    if evaluator is not None:
      logging.info('Running inference eval before training.')
      _run_inference_eval()

  # ----------------------------------------------------------------------------
  # Main training loop
  # ----------------------------------------------------------------------------
  logging.info('Starting training loop.')

  first_step = host_step

  if total_steps < first_step:
    raise ValueError(
        f'Unexpected total_steps ({total_steps}) < checkpoint step '
        f' ({first_step}).')

  logging.info('Starting main loop over steps %d-%d', first_step, total_steps)

  steps_per_epoch = min(steps_per_epoch, total_steps)
  first_epoch = first_step // steps_per_epoch
  num_epochs = first_epoch + math.ceil(
      (total_steps - first_step) / steps_per_epoch)
  logging.info('Training with artificial "epochs" of %d steps.',
               steps_per_epoch)

  logging.info('Compiling train loop.')
  logging.flush()
  dummy_batch = {
      k: np.ones(v.shape, v.dtype) for k, v in train_iter.element_spec.items()
  }
  trainer.compile_train(dummy_batch)

  # Main Loop over "epochs".
  for epoch in range(first_epoch, num_epochs):
    final_epoch = epoch == num_epochs - 1
    logging.info('Epoch %d of %d', epoch, num_epochs)

    # `stop_training` is requested, break out the main loop immediately.
    if trainer.stop_training:
      break

    logging.info('BEGIN Train loop.')
    try:
      # Until the last epoch, `num_steps = steps_per_epoch`
      num_steps = min(total_steps - host_step, steps_per_epoch)
      epoch_end_step = host_step + num_steps
      logging.info('Training for %d steps.', num_steps)
      while host_step < epoch_end_step:
        if trainer.stop_training:
          # NOTE(review): assumes `checkpoint_cfg.save` is not None here.
          logging.info('Saving a checkpoint before early stopping...')
          checkpoint_manager.save(trainer.train_state,
                                  checkpoint_cfg.save.state_transformation_fns)
          logging.info('Stopping training loop early since `stop_training` is '
                       'requested.')
          break

        # Train in chunks of at most `stats_period` steps within the epoch.
        inner_num_steps = min(epoch_end_step - host_step, stats_period)
        train_summary = trainer.train(
            train_iter, inner_num_steps, start_step=host_step)
        if not concurrent_metrics:
          # Note that we always pass the dictionary of `tasks` -> summary so
          # that the actions can be performed without special casing. The only
          # caveat is that train would need its own special `key` given no
          # `task` will be applied.
          trainer.stop_training = run_actions(
              trainer_lib.ActionMode.TRAIN, actions, trainer.train_state,
              {TRAIN_METRIC_KEY: train_summary.result()})

        host_step += inner_num_steps
      logging.info('END Train loop.')
    except trainer_lib.PreemptionError:
      # NOTE(review): assumes `checkpoint_cfg.save` is not None here.
      logging.info('Saving emergency checkpoint.')
      checkpoint_manager.save(trainer.train_state,
                              checkpoint_cfg.save.state_transformation_fns)
      logging.info('Saving emergency checkpoint done.')
      # Bare `raise` re-raises the active exception with its original traceback.
      raise

    step_offset = host_step - first_step

    # Maybe save a checkpoint.
    if checkpoint_period and (final_epoch or
                              step_offset % checkpoint_period == 0):
      # Make sure last train step has completed before starting the clock.
      train_summary.result()
      logging.info('Saving checkpoint.')
      checkpoint_tick = time.time()
      checkpoint_manager.save(trainer.train_state,
                              checkpoint_cfg.save.state_transformation_fns)
      checkpoint_tock = time.time()
      train_metrics.write_scalar('timing/checkpoint_seconds',
                                 checkpoint_tock - checkpoint_tick, host_step)

    is_eval_epoch = eval_period and (final_epoch or
                                     step_offset % eval_period == 0)

    # Training Evaluation (i.e., with teacher forcing).
    if is_eval_epoch and train_eval_datasets:
      # Maybe less if final step < period.
      first_run = step_offset // eval_period <= 1
      _run_training_eval(first_run and not run_eval_before_training)

    # Inference Evaluation (i.e., with decoding or scoring).
    # Gated on `is_eval_epoch` so that it honours `eval_period` like train-eval
    # does, instead of also running on checkpoint-only epochs.
    if is_eval_epoch and evaluator is not None:
      _run_inference_eval()

  # Wait until computations are done before exiting
  logging.info('Finished.')
  trainer.close()
  if evaluator:
    evaluator.close()
  multihost_utils.sync_global_devices('complete')

  return host_step, trainer.train_state
if __name__ == '__main__':
  # Script entry point: defines the command-line flags, then hands control to
  # `gin_utils.run`, passing `main` as the function to run.
  # pylint: disable=g-import-not-at-top
  from absl import app
  from absl import flags
  import gin
  from t5x import gin_utils
  # pylint: enable=g-import-not-at-top

  FLAGS = flags.FLAGS

  jax.config.parse_flags_with_absl()

  flags.DEFINE_multi_string(
      'gin_file',
      default=None,
      help='Path to gin configuration file. Multiple paths may be passed and '
      'will be imported in the given order, with later configurations '
      'overriding earlier ones.')

  flags.DEFINE_multi_string(
      'gin_bindings', default=[], help='Individual gin bindings.')

  # The help text previously contained the broken fragment
  # "If a file appears in."; the sentence is completed here.
  flags.DEFINE_list(
      'gin_search_paths',
      default=['.'],
      help='Comma-separated list of gin config path prefixes to be prepended '
      'to suffixes given via `--gin_file`. If a file appears under more than '
      'one prefix, only the first prefix that produces a valid path for each '
      'suffix will be used.')

  flags.DEFINE_string(
      'tfds_data_dir', None,
      'If set, this directory will be used to store datasets prepared by '
      'TensorFlow Datasets that are not available in the public TFDS GCS '
      'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
      'all `Task`s.')

  flags.DEFINE_list(
      'seqio_additional_cache_dirs', [],
      'Directories to search for cached Tasks in addition to defaults.')

  def main(argv: Sequence[str]):
    """Wrapper for pdb post mortems."""
    _main(argv)

  def _main(argv: Sequence[str]):
    """True main function.

    Applies the TFDS / SeqIO cache flags, parses the gin flags, and then calls
    a gin-configurable wrapper of `train` with no explicit arguments, so every
    argument of `train` has to come from the gin configuration.

    Args:
      argv: Command-line arguments; only the first entry is accepted.

    Raises:
      app.UsageError: If more than one entry is present in `argv`.
    """
    if len(argv) > 1:
      raise app.UsageError('Too many command-line arguments.')

    if FLAGS.tfds_data_dir:
      seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir)

    seqio.add_global_cache_dirs(FLAGS.seqio_additional_cache_dirs)

    # Create gin-configurable version of `train`.
    train_using_gin = gin.configurable(train)

    gin_utils.parse_gin_flags(
        # User-provided gin paths take precedence if relative paths conflict.
        FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
        FLAGS.gin_file,
        FLAGS.gin_bindings)
    train_using_gin()

  gin_utils.run(main)