# Pre-train a Mixture of Experts model.
#
# This file supports pre-training with data, expert and model parallelism. To
# use model parallelism, set NUM_MODEL_PARTITIONS > 1.
#
# You must also provide a binding for MODEL (e.g. by including a model gin
# file).
#
# Required to be set:
#
# - NUM_EXPERTS
# - NUM_MODEL_PARTITIONS (1 if no model parallelism)
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS
# - MODEL_DIR
#
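# Commonly overridden options (see also t5x/configs/runs/pretrain.gin):
#
# - BATCH_SIZE
# - moe_trainer.MoeTrainer.num_microbatches (this config trains with MoeTrainer
#   and binds num_microbatches on it explicitly, so override that binding)
# - DROPOUT_RATE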
# dynamic_registration lets the bindings below refer to configurables through
# the module aliases imported here (e.g. moe_partitioning.MoePjitPartitioner).
# It must be the first import in the file.
from __gin__ import dynamic_registration

import __main__ as train_script
from t5x.contrib.moe import partitioning as moe_partitioning
from t5x.contrib.moe import trainer as moe_trainer
from t5x import utils

include 't5x/configs/runs/pretrain.gin'

# Mandatory macros: %gin.REQUIRED marks them as values the user must bind.
NUM_EXPERTS = %gin.REQUIRED
NUM_MODEL_PARTITIONS = %gin.REQUIRED
# We use the MoE partitioner. It receives the expert count and the number of
# model partitions, and its logical axis rules are built from the same two
# macros so the partitioner and the rules always agree.
# Fully-qualified bindings are used (rather than an indented block) so the
# parameters cannot be misparsed as top-level macros.
train_script.train.partitioner = @moe_partitioning.MoePjitPartitioner()
moe_partitioning.MoePjitPartitioner.num_experts = %NUM_EXPERTS
moe_partitioning.MoePjitPartitioner.num_partitions = %NUM_MODEL_PARTITIONS
moe_partitioning.MoePjitPartitioner.logical_axis_rules = @moe_partitioning.standard_logical_axis_rules()

moe_partitioning.standard_logical_axis_rules.num_experts = %NUM_EXPERTS
moe_partitioning.standard_logical_axis_rules.num_partitions = %NUM_MODEL_PARTITIONS
# And the MoE trainer. Fully-qualified bindings are used (rather than an
# indented block) so the parameters cannot be misparsed as top-level macros.
train_script.train.trainer_cls = @moe_trainer.MoeTrainer
# num_microbatches is bound explicitly on MoeTrainer (presumably None means no
# microbatching); override this binding, not one on the base Trainer.
moe_trainer.MoeTrainer.num_microbatches = None
moe_trainer.MoeTrainer.learning_rate_fn = @utils.create_learning_rate_scheduler()
moe_trainer.MoeTrainer.num_experts = %NUM_EXPERTS

# Learning-rate schedule: product of the 'constant' and 'rsqrt_decay' factors.
utils.create_learning_rate_scheduler.factors = 'constant * rsqrt_decay'
utils.create_learning_rate_scheduler.base_learning_rate = 1.0
# 10k to keep consistent with T5/MTF defaults.
utils.create_learning_rate_scheduler.warmup_steps = 10000
# Checkpoint retention and cadence: retain slightly fewer checkpoints than the
# pre-training defaults.
utils.SaveCheckpointConfig.keep = 20
utils.SaveCheckpointConfig.period = 5000