juancopi81's picture
Add t5x and mt3 models
b100e1c
# Defaults for pretraining with train.py.
#
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS
# - MODEL_DIR: # automatically set when using xm_launch
#
# Commonly overridden options:
#
# - train/DatasetConfig.batch_size
# - train_eval/DatasetConfig.batch_size
# - PjitPartitioner.num_partitions
# - Trainer.num_microbatches
# - DROPOUT_RATE
from __gin__ import dynamic_registration
import __main__ as train_script
from t5x import gin_utils
from t5x import partitioning
from t5x import utils
from t5x import trainer
# Required macros (listed in the header). `%gin.REQUIRED` is a placeholder:
# each of these must be bound by an including gin file or on the command line.
MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED
MODEL_DIR = %gin.REQUIRED
# Shared by the `train/` and `train_eval/` DatasetConfig bindings below.
BATCH_SIZE = 128
# Forwarded as `use_cached` to both DatasetConfig bindings below.
USE_CACHED_TASKS = True
# DEPRECATED: Import this module in your gin file instead.
# Forwarded as `module` to both DatasetConfig bindings below.
MIXTURE_OR_TASK_MODULE = None
# Forwarded as `shuffle` to the `train/` DatasetConfig only.
SHUFFLE_TRAIN_EXAMPLES = True
# HW RNG is faster than SW, but has limited determinism.
# Most notably it is not deterministic across different
# submeshes.
USE_HARDWARE_RNG = False
# None always uses faster, hardware RNG
RANDOM_SEED = None
# Can be overridden with `train.*`.
# Arguments bound on the `train` function of the training script (imported
# above as `train_script`).
train_script.train:
  model = %MODEL # imported from separate gin file
  model_dir = %MODEL_DIR
  # `@scope/name()` binds the result of calling the scoped configurable; the
  # `train/` and `train_eval/` scopes select the matching DatasetConfig
  # bindings defined later in this file.
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  # No infer-eval dataset is configured by these defaults.
  infer_eval_dataset_cfg = None
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  # No trailing `()`: the configurable itself is passed, not a call result.
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS
  # NOTE(review): presumably the number of eval batches per evaluation and the
  # number of train steps between evaluations -- confirm against train.py.
  eval_steps = 20
  eval_period = 1000
  random_seed = %RANDOM_SEED
  use_hardware_rng = %USE_HARDWARE_RNG
  # Passed as a function reference (no trailing `()`).
  summarize_config_fn = @gin_utils.summarize_gin_config
# Default partitioner settings. The file header lists
# `PjitPartitioner.num_partitions` among the commonly overridden options.
partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None
  # The trailing `()` binds the function's return value, not the function.
  logical_axis_rules = @partitioning.standard_logical_axis_rules()
# DatasetConfig in the `train/` scope; referenced by
# `train_script.train.train_dataset_cfg`.
train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = %SHUFFLE_TRAIN_EXAMPLES
  seed = None # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = True
  # Deprecated hook; see the comment on MIXTURE_OR_TASK_MODULE.
  module = %MIXTURE_OR_TASK_MODULE
# DatasetConfig in the `train_eval/` scope; referenced by
# `train_script.train.train_eval_dataset_cfg`. Differs from the `train/`
# config in its split, and in using no shuffling with a fixed seed.
train_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = True
  # Deprecated hook; see the comment on MIXTURE_OR_TASK_MODULE.
  module = %MIXTURE_OR_TASK_MODULE
# Checkpoint configuration, composed from the separate restore and save
# configs bound later in this file.
utils.CheckpointConfig:
  restore = @utils.RestoreCheckpointConfig()
  save = @utils.SaveCheckpointConfig()
# Restore settings: the default is an empty path list.
utils.RestoreCheckpointConfig:
  path = [] # initialize from scratch
# Save settings for checkpoints written during training.
utils.SaveCheckpointConfig:
  # Same value as `train_script.train.eval_period` in this file.
  # NOTE(review): presumably measured in train steps -- confirm.
  period = 1000
  dtype = 'float32'
  keep = None # keep all checkpoints
  save_dataset = False # don't checkpoint dataset state
# Trainer defaults. The file header lists `Trainer.num_microbatches` among the
# commonly overridden options.
trainer.Trainer:
  num_microbatches = None
  # The trailing `()` binds the scheduler factory's return value.
  learning_rate_fn = @utils.create_learning_rate_scheduler()
# Arguments for the learning-rate schedule used by `trainer.Trainer`.
utils.create_learning_rate_scheduler:
  # NOTE(review): presumably a product of named schedule factors -- confirm
  # the accepted factor names in t5x.utils.
  factors = 'constant * rsqrt_decay'
  base_learning_rate = 1.0
  warmup_steps = 10000 # 10k to keep consistent with T5/MTF defaults.