# Defaults for pretraining with train.py.
#
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS - include pretraining steps (total steps, not additional steps)
# - INITIAL_CHECKPOINT_PATH
# - MODEL_DIR (automatically set when using xm_launch)
#
# Commonly overridden options:
#
# - train/DatasetConfig.batch_size
# - train_eval/DatasetConfig.batch_size
# - PjitPartitioner.num_partitions
# - Trainer.num_microbatches
# - DROPOUT_RATE
from __gin__ import dynamic_registration

import __main__ as train_script
from t5x import gin_utils
from t5x import partitioning
from t5x import utils
from t5x import trainer

MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED
MODEL_DIR = %gin.REQUIRED
BATCH_SIZE = 128
USE_CACHED_TASKS = True
INITIAL_CHECKPOINT_PATH = %gin.REQUIRED

# DEPRECATED: Import this module in your gin file.
MIXTURE_OR_TASK_MODULE = None
SHUFFLE_TRAIN_EXAMPLES = True

# HW RNG is faster than SW, but has limited determinism.
# Most notably it is not deterministic across different
# submeshes.
USE_HARDWARE_RNG = False
# If None, a fast, non-deterministic hardware RNG is used.
RANDOM_SEED = None

# Can be overridden with `train.*`.
train_script.train:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = None
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS
  eval_steps = 20
  eval_period = 1000
  random_seed = %RANDOM_SEED
  use_hardware_rng = %USE_HARDWARE_RNG
  summarize_config_fn = @gin_utils.summarize_gin_config

partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None
  logical_axis_rules = @partitioning.standard_logical_axis_rules()

train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = %SHUFFLE_TRAIN_EXAMPLES
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE

train_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE

utils.CheckpointConfig:
  restore = @utils.RestoreCheckpointConfig()
  save = @utils.SaveCheckpointConfig()

utils.RestoreCheckpointConfig:
  path = %INITIAL_CHECKPOINT_PATH
  mode = 'specific'
  dtype = 'float32'

utils.SaveCheckpointConfig:
  period = 1000
  dtype = 'float32'
  keep = None  # keep all checkpoints
  save_dataset = False  # don't checkpoint dataset state

trainer.Trainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()

utils.create_learning_rate_scheduler:
  factors = 'constant * rsqrt_decay'
  base_learning_rate = 0.5  # Half of the original value, since this is continued training.
  warmup_steps = 10000  # 10k to keep consistent with T5/MTF defaults.
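
# -----------------------------------------------------------------------------
# Note on the schedule above (a sketch, assuming the standard behavior of
# t5x.utils.create_learning_rate_scheduler with 'constant * rsqrt_decay'
# factors): the learning rate is approximately
#
#   lr(step) = base_learning_rate / sqrt(max(step, warmup_steps))
#
# so with base_learning_rate = 0.5 and warmup_steps = 10000, the rate is
# 0.5 / sqrt(10000) = 0.005 for the first 10k steps and decays as
# 0.5 / sqrt(step) afterwards.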
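
# -----------------------------------------------------------------------------
# Illustrative launch (a sketch, not part of the config): one way the required
# bindings above can be supplied on the command line via `--gin_file` and
# `--gin.<NAME>` flags. All paths, the mixture name, and the step counts below
# are placeholders; substitute your own values.
#
#   python ${T5X_DIR}/t5x/train.py \
#     --gin_file=path/to/your_model.gin \
#     --gin_file=path/to/this_file.gin \
#     --gin.MODEL_DIR=\"'gs://your-bucket/your-run'\" \
#     --gin.MIXTURE_OR_TASK_NAME=\"'your_mixture_or_task'\" \
#     --gin.TASK_FEATURE_LENGTHS=\"{'inputs': 512, 'targets': 114}\" \
#     --gin.TRAIN_STEPS=1100000 \
#     --gin.INITIAL_CHECKPOINT_PATH=\"'gs://your-bucket/checkpoint_1000000'\"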