# Defaults for finetuning with train.py.
#
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS  # includes pretrain steps
# - MODEL_DIR  # automatically set when using xm_launch
# - INITIAL_CHECKPOINT_PATH
#
# When running locally, MODEL_DIR needs to be passed via the `--gin.MODEL_DIR` flag.
#
# `TRAIN_STEPS` should include pre-training steps, e.g., if the pre-trained ckpt
# has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps.
#
# Commonly overridden options:
# - DROPOUT_RATE
# - BATCH_SIZE
# - PjitPartitioner.num_partitions
# - Trainer.num_microbatches
# - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess
#   on the fly. Most common tasks are cached, hence this is set to True by
#   default.

from __gin__ import dynamic_registration

import __main__ as train_script
import seqio
from t5x import gin_utils
from t5x import partitioning
from t5x import utils
from t5x import trainer

# Must be overridden
MODEL_DIR = %gin.REQUIRED
MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
MIXTURE_OR_TASK_MODULE = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED
INITIAL_CHECKPOINT_PATH = %gin.REQUIRED

# Commonly overridden
DROPOUT_RATE = 0.1
USE_CACHED_TASKS = True
BATCH_SIZE = 128

# Sometimes overridden
EVAL_STEPS = 20

# Convenience overrides.
EVALUATOR_USE_MEMORY_CACHE = True
EVALUATOR_NUM_EXAMPLES = None  # Use all examples in the infer_eval dataset.
JSON_WRITE_N_RESULTS = None  # Write all inferences.

# HW RNG is faster than SW, but has limited determinism.
# Most notably it is not deterministic across different
# submeshes.
USE_HARDWARE_RNG = False
# None always uses the faster, hardware RNG.
RANDOM_SEED = None
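
# Example (illustrative only, commented out): a user gin file that builds on
# these defaults might set the required values as below. The include path,
# mixture name, feature lengths, step count, checkpoint path, and bucket are
# placeholders, not values shipped with this config.
#
# include 'path/to/this/finetune_defaults.gin'
#
# MIXTURE_OR_TASK_NAME = 'my_seqio_mixture'
# TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 114}
# TRAIN_STEPS = 1_100_000  # 1M pre-training steps + 100k fine-tuning steps
# INITIAL_CHECKPOINT_PATH = 'gs://my-bucket/pretrained_model/checkpoint_1000000'
# MODEL_DIR = 'gs://my-bucket/finetune_run'  # or pass via --gin.MODEL_DIR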
# DEPRECATED: Import the task module in your gin file instead.
MIXTURE_OR_TASK_MODULE = None

train_script.train:
  model = %MODEL  # imported from a separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS
  eval_steps = %EVAL_STEPS
  eval_period = 1000
  random_seed = %RANDOM_SEED
  use_hardware_rng = %USE_HARDWARE_RNG
  summarize_config_fn = @gin_utils.summarize_gin_config
  inference_evaluator_cls = @seqio.Evaluator

partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None
  logical_axis_rules = @partitioning.standard_logical_axis_rules()

seqio.Evaluator:
  logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
  num_examples = %EVALUATOR_NUM_EXAMPLES
  use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE

seqio.JSONLogger:
  write_n_results = %JSON_WRITE_N_RESULTS

train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = True
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE

train_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE

infer_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = None  # compute max
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False
  module = %MIXTURE_OR_TASK_MODULE

utils.CheckpointConfig:
  restore = @utils.RestoreCheckpointConfig()
  save = @utils.SaveCheckpointConfig()

utils.RestoreCheckpointConfig:
  path = %INITIAL_CHECKPOINT_PATH
  mode = 'specific'
  dtype = 'float32'

utils.SaveCheckpointConfig:
  period = 5000
  dtype = 'float32'
  keep = None  # keep all checkpoints
  save_dataset = False  # don't checkpoint dataset state

trainer.Trainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()

utils.create_learning_rate_scheduler:
  factors = 'constant'
  base_learning_rate = 0.001
  warmup_steps = 1000
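
# Example (illustrative only, commented out): common overrides for larger
# models or more frequent checkpointing, using the configurables bound above.
# The values are arbitrary placeholders, not recommendations.
#
# partitioning.PjitPartitioner.num_partitions = 4
# trainer.Trainer.num_microbatches = 2
# utils.SaveCheckpointConfig.period = 1000
# utils.create_learning_rate_scheduler.base_learning_rate = 0.0005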