Spaces:
Build error
Build error
# Defaults for finetuning with train.py. | |
# | |
# | |
# You must also include a binding for MODEL. | |
# | |
# Required to be set: | |
# | |
# - MIXTURE_OR_TASK_NAME | |
# - TASK_FEATURE_LENGTHS | |
# - TRAIN_STEPS # includes pretrain steps | |
# - MODEL_DIR # automatically set when using xm_launch | |
# - INITIAL_CHECKPOINT_PATH | |
# | |
# When running locally, MODEL_DIR must be passed via the `--gin.MODEL_DIR` flag.
# | |
# `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt | |
# has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps. | |
# | |
# Commonly overridden options: | |
# - DROPOUT_RATE | |
# - BATCH_SIZE | |
# - PjitPartitioner.num_partitions | |
# - Trainer.num_microbatches | |
# - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess | |
# on the fly. Most common tasks are cached, hence this is set to True by | |
# default. | |
from __gin__ import dynamic_registration | |
import __main__ as train_script | |
import seqio | |
from t5x import gin_utils | |
from t5x import partitioning | |
from t5x import utils | |
from t5x import trainer | |
# Must be overridden
MODEL_DIR = %gin.REQUIRED
MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED  # includes pre-training steps
INITIAL_CHECKPOINT_PATH = %gin.REQUIRED

# Commonly overridden
DROPOUT_RATE = 0.1
USE_CACHED_TASKS = True
BATCH_SIZE = 128

# Sometimes overridden
EVAL_STEPS = 20

# Convenience overrides.
EVALUATOR_USE_MEMORY_CACHE = True
EVALUATOR_NUM_EXAMPLES = None  # Use all examples in the infer_eval dataset.
JSON_WRITE_N_RESULTS = None  # Write all inferences.

# HW RNG is faster than SW, but has limited determinism.
# Most notably it is not deterministic across different
# submeshes.
USE_HARDWARE_RNG = False
# None always uses faster, hardware RNG
RANDOM_SEED = None

# DEPRECATED: import the module that registers your mixture/task in your own
# gin file instead of setting this macro. It defaults to None and is not
# required; it was previously also bound to %gin.REQUIRED, which this
# binding silently overrode.
MIXTURE_OR_TASK_MODULE = None
# Arguments for the `train` function of the script running as __main__
# (train.py, per the file header).
train_script.train:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS  # includes pre-training steps
  eval_steps = %EVAL_STEPS
  eval_period = 1000
  random_seed = %RANDOM_SEED
  use_hardware_rng = %USE_HARDWARE_RNG
  summarize_config_fn = @gin_utils.summarize_gin_config
  inference_evaluator_cls = @seqio.Evaluator
# Partitioner passed to train(); `num_partitions` is one of the commonly
# overridden options listed in the file header.
partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None
  logical_axis_rules = @partitioning.standard_logical_axis_rules()
# Inference evaluator (train's inference_evaluator_cls); results are logged to
# Python logging, TensorBoard and JSON.
seqio.Evaluator:
  logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
  num_examples = %EVALUATOR_NUM_EXAMPLES  # None: use all infer_eval examples
  use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE

seqio.JSONLogger:
  write_n_results = %JSON_WRITE_N_RESULTS  # None: write all inferences
# Training data: 'train' split, shuffled and packed.
train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = True
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE  # deprecated macro, None by default
# train_eval data: 'validation' split with the same feature lengths and packing
# as training, unshuffled with a fixed seed.
train_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE  # deprecated macro, None by default
# infer_eval data (train's infer_eval_dataset_cfg): 'validation' split,
# unshuffled and unpacked; feature lengths are not fixed here.
infer_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = None  # compute max
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False
  module = %MIXTURE_OR_TASK_MODULE  # deprecated macro, None by default
# Checkpointing: restore from INITIAL_CHECKPOINT_PATH, save every 5000 steps.
utils.CheckpointConfig:
  restore = @utils.RestoreCheckpointConfig()
  save = @utils.SaveCheckpointConfig()

utils.RestoreCheckpointConfig:
  path = %INITIAL_CHECKPOINT_PATH
  # NOTE(review): presumably 'specific' restores exactly the checkpoint at
  # `path` rather than the latest one in a directory; confirm against t5x.utils.
  mode = 'specific'
  dtype = 'float32'

utils.SaveCheckpointConfig:
  period = 5000
  dtype = 'float32'
  keep = None  # keep all checkpoints
  save_dataset = False  # don't checkpoint dataset state
# Trainer and its learning-rate schedule. `num_microbatches` is one of the
# commonly overridden options listed in the file header.
trainer.Trainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()

utils.create_learning_rate_scheduler:
  factors = 'constant'
  base_learning_rate = 0.001
  # NOTE(review): with factors = 'constant' alone, warmup_steps likely has no
  # effect (it is expected to matter only for warmup/decay factors); confirm
  # before relying on a warmup here.
  warmup_steps = 1000