# Pre-train a Mixture of Experts model.
#
# This file allows for pre-training with data, expert and model parallelism. To
# use model parallelism, set NUM_MODEL_PARTITIONS > 1.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - NUM_EXPERTS
# - NUM_MODEL_PARTITIONS (1 if no model parallelism)
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS
# - MODEL_DIR
#
# Commonly overridden options (see also t5x/configs/runs/pretrain.gin):
#
# - BATCH_SIZE
# - Trainer.num_microbatches
# - DROPOUT_RATE

from __gin__ import dynamic_registration

import __main__ as train_script
from t5x.contrib.moe import partitioning as moe_partitioning
from t5x.contrib.moe import trainer as moe_trainer
from t5x import utils

include 't5x/configs/runs/pretrain.gin'

NUM_EXPERTS = %gin.REQUIRED
NUM_MODEL_PARTITIONS = %gin.REQUIRED

# We use the MoE partitioner.
train_script.train.partitioner = @moe_partitioning.MoePjitPartitioner()
moe_partitioning.MoePjitPartitioner:
  num_experts = %NUM_EXPERTS
  num_partitions = %NUM_MODEL_PARTITIONS
  logical_axis_rules = @moe_partitioning.standard_logical_axis_rules()

moe_partitioning.standard_logical_axis_rules:
  num_experts = %NUM_EXPERTS
  num_partitions = %NUM_MODEL_PARTITIONS

# And the MoE trainer.
train_script.train.trainer_cls = @moe_trainer.MoeTrainer
moe_trainer.MoeTrainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()
  num_experts = %NUM_EXPERTS

utils.create_learning_rate_scheduler:
  factors = 'constant * rsqrt_decay'
  base_learning_rate = 1.0
  warmup_steps = 10000  # 10k to keep consistent with T5/MTF defaults.

# Keep slightly fewer checkpoints than pre-training defaults.
utils.SaveCheckpointConfig.period = 5000
utils.SaveCheckpointConfig.keep = 20
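
# Example launch (a sketch, not part of this config): the command below shows
# one way to supply the required bindings via the standard t5x train CLI
# (`--gin_file` and `--gin.BINDING=value` flags). The gin file paths, task
# name, feature lengths, step count, and MODEL_DIR are illustrative
# placeholders, not defaults defined in this file.
#
#   python -m t5x.train \
#     --gin_file="t5x/contrib/moe/configs/runs/pretrain.gin" \
#     --gin_file="path/to/your/moe_model.gin" \
#     --gin.NUM_EXPERTS=8 \
#     --gin.NUM_MODEL_PARTITIONS=1 \
#     --gin.MIXTURE_OR_TASK_NAME=\"c4_v220_span_corruption\" \
#     --gin.TASK_FEATURE_LENGTHS=\"{'inputs': 512, 'targets': 114}\" \
#     --gin.TRAIN_STEPS=100000 \
#     --gin.MODEL_DIR=\"gs://your-bucket/model_dir\"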