# Pre-train a Mixture of Experts model.
#
# This file allows for pre-training with data, expert and model parallelism. To
# use model parallelism, set NUM_MODEL_PARTITIONS > 1.
#
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - NUM_EXPERTS
# - NUM_MODEL_PARTITIONS (1 if no model parallelism)
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS
# - MODEL_DIR
#
# Commonly overridden options (see also t5x/configs/runs/pretrain.gin):
#
# - BATCH_SIZE
# - Trainer.num_microbatches
# - DROPOUT_RATE
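#
# As a purely hypothetical illustration (the values below are placeholders,
# not recommendations), the required bindings can be supplied from a small
# override gin file that includes this one, or as --gin.<PARAM>=<value>
# flag overrides to the T5X launch script, e.g.:
#
#   MIXTURE_OR_TASK_NAME = 'c4_v220_span_corruption'
#   TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 114}
#   TRAIN_STEPS = 100000
#   NUM_EXPERTS = 8
#   NUM_MODEL_PARTITIONS = 1
#   MODEL_DIR = '/path/to/model_dir'
#
# MODEL itself is typically bound by additionally including a model gin file.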

from __gin__ import dynamic_registration

import __main__ as train_script
from t5x.contrib.moe import partitioning as moe_partitioning
from t5x.contrib.moe import trainer as moe_trainer
from t5x import utils

include 't5x/configs/runs/pretrain.gin'

NUM_EXPERTS = %gin.REQUIRED
NUM_MODEL_PARTITIONS = %gin.REQUIRED

# We use the MoE partitioner.
train_script.train.partitioner = @moe_partitioning.MoePjitPartitioner()
moe_partitioning.MoePjitPartitioner:
  num_experts = %NUM_EXPERTS
  num_partitions = %NUM_MODEL_PARTITIONS
  logical_axis_rules = @moe_partitioning.standard_logical_axis_rules()

moe_partitioning.standard_logical_axis_rules:
  num_experts = %NUM_EXPERTS
  num_partitions = %NUM_MODEL_PARTITIONS

# And the MoE trainer.
train_script.train.trainer_cls = @moe_trainer.MoeTrainer
moe_trainer.MoeTrainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()
  num_experts = %NUM_EXPERTS

utils.create_learning_rate_scheduler:
  factors = 'constant * rsqrt_decay'
  base_learning_rate = 1.0
  warmup_steps = 10000  # 10k to keep consistent with T5/MTF defaults.
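
# For reference, assuming the usual T5X semantics of 'constant * rsqrt_decay',
# these settings give lr(step) = base_learning_rate / sqrt(max(step, warmup_steps)),
# i.e. a constant 1e-2 for the first 10k steps followed by inverse-square-root decay.
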
# Keep slightly fewer checkpoints than pre-training defaults.
utils.SaveCheckpointConfig.period = 5000
utils.SaveCheckpointConfig.keep = 20
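
# With the standard T5X semantics for `period` and `keep`, this writes a
# checkpoint every 5,000 steps and retains only the 20 most recent ones,
# i.e. roughly the last 100k training steps remain restorable.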