# Defaults for pretraining with train.py, continuing from an existing
# checkpoint.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS - includes any steps already completed by the restored
#                 checkpoint
# - INITIAL_CHECKPOINT_PATH - the checkpoint to continue training from
# - MODEL_DIR: # automatically set when using xm_launch
#
# Commonly overridden options:
#
# - train/DatasetConfig.batch_size
# - train_eval/DatasetConfig.batch_size
# - PjitPartitioner.num_partitions
# - Trainer.num_microbatches
# - DROPOUT_RATE
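#
# A minimal launch config might look like the following; the task name,
# feature lengths, and paths are illustrative placeholders, not defaults:
#
#   include 'path/to/this_file.gin'  # plus the gin file that binds MODEL
#   MIXTURE_OR_TASK_NAME = 'my_span_corruption_task'
#   TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 114}
#   TRAIN_STEPS = 1100000
#   INITIAL_CHECKPOINT_PATH = 'gs://my-bucket/ckpts/checkpoint_1000000'
#   MODEL_DIR = 'gs://my-bucket/model_dir'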
from __gin__ import dynamic_registration

import __main__ as train_script
from t5x import gin_utils
from t5x import partitioning
from t5x import trainer
from t5x import utils

MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED
INITIAL_CHECKPOINT_PATH = %gin.REQUIRED
MODEL_DIR = %gin.REQUIRED
BATCH_SIZE = 128
USE_CACHED_TASKS = True
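
# Note: %gin.REQUIRED makes config parsing fail fast if a binding above is
# left unset, rather than silently training with a default value.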

# DEPRECATED: Import this module in your gin file.
MIXTURE_OR_TASK_MODULE = None
SHUFFLE_TRAIN_EXAMPLES = True

# HW RNG is faster than SW, but has limited determinism. Most notably, it is
# not deterministic across different submeshes.
USE_HARDWARE_RNG = False
# When RANDOM_SEED is None, the faster hardware RNG is always used.
RANDOM_SEED = None
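
# For a fully reproducible run (at some cost in speed), one could instead fix
# the seed, e.g.:
#
#   RANDOM_SEED = 0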

# Can be overridden with `train.*` bindings.
train_script.train:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = None
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS
  eval_steps = 20
  eval_period = 1000
  random_seed = %RANDOM_SEED
  use_hardware_rng = %USE_HARDWARE_RNG
  summarize_config_fn = @gin_utils.summarize_gin_config
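
# Any of these bindings can also be overridden at launch time via `--gin`
# flags on the standard t5x launcher, e.g. (illustrative values):
#
#   python -m t5x.train \
#     --gin_file=path/to/this_file.gin \
#     --gin.BATCH_SIZE=256 \
#     --gin.TRAIN_STEPS=1100000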

partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None
  logical_axis_rules = @partitioning.standard_logical_axis_rules()
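
# For models too large for pure data parallelism, parameters can be sharded
# across devices by raising model parallelism, e.g. (illustrative):
#
#   partitioning.PjitPartitioner.num_partitions = 2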

train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = %SHUFFLE_TRAIN_EXAMPLES
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE

train_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE

utils.CheckpointConfig:
  restore = @utils.RestoreCheckpointConfig()
  save = @utils.SaveCheckpointConfig()

utils.RestoreCheckpointConfig:
  path = %INITIAL_CHECKPOINT_PATH
  mode = 'specific'
  dtype = 'float32'

utils.SaveCheckpointConfig:
  period = 1000
  dtype = 'float32'
  keep = None  # keep all checkpoints
  save_dataset = False  # don't checkpoint dataset state
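
# Saved checkpoints can consume substantial disk space; to retain only the
# most recent N checkpoints, one could bind, e.g.:
#
#   utils.SaveCheckpointConfig.keep = 10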

trainer.Trainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()
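
# If a batch does not fit in device memory, gradients can be accumulated over
# smaller microbatches instead, e.g. (illustrative):
#
#   trainer.Trainer.num_microbatches = 4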

utils.create_learning_rate_scheduler:
  factors = 'constant * rsqrt_decay'
  # Half the usual pretraining rate, since training continues from an existing
  # checkpoint.
  base_learning_rate = 0.5
  warmup_steps = 10000  # 10k to keep consistent with T5/MTF defaults.
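
# With these factors the resulting schedule is
#   lr(step) = base_learning_rate / sqrt(max(step, warmup_steps))
# e.g. at step 1,000,000 the rate is 0.5 / sqrt(1000000) = 0.0005.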