# Fine-tune a Mixture of Experts model.
#
# This file allows for fine-tuning with data, expert and model parallelism. To
# use model parallelism, set NUM_MODEL_PARTITIONS > 1.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - NUM_EXPERTS
# - NUM_MODEL_PARTITIONS (1 if no model parallelism)
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS  # includes pretrain steps
# - MODEL_DIR
# - INITIAL_CHECKPOINT_PATH
#
# Commonly overridden options (see also t5x/configs/runs/finetune.gin):
#
# - DROPOUT_RATE
# - BATCH_SIZE
# - Trainer.num_microbatches

from __gin__ import dynamic_registration

import __main__ as train_script
from t5x.contrib.moe import partitioning as moe_partitioning
from t5x.contrib.moe import trainer as moe_trainer
from t5x import utils

include 't5x/configs/runs/finetune.gin'

NUM_EXPERTS = %gin.REQUIRED
NUM_MODEL_PARTITIONS = %gin.REQUIRED

# We use the MoE partitioner.
train_script.train.partitioner = @moe_partitioning.MoePjitPartitioner()
moe_partitioning.MoePjitPartitioner:
  num_experts = %NUM_EXPERTS
  num_partitions = %NUM_MODEL_PARTITIONS
  logical_axis_rules = @moe_partitioning.standard_logical_axis_rules()

moe_partitioning.standard_logical_axis_rules:
  num_experts = %NUM_EXPERTS
  num_partitions = %NUM_MODEL_PARTITIONS

# And the MoE trainer.
train_script.train.trainer_cls = @moe_trainer.MoeTrainer
moe_trainer.MoeTrainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()
  num_experts = %NUM_EXPERTS

utils.create_learning_rate_scheduler:
  factors = 'constant'
  base_learning_rate = 0.001
  warmup_steps = 1000

# Checkpoint slightly more often than fine-tuning defaults.
utils.SaveCheckpointConfig.period = 2000
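
# Example overrides (a minimal sketch, kept commented out so this file's
# defaults are unchanged; the values below are illustrative placeholders,
# not tuned recommendations from this config):
#
# NUM_EXPERTS = 8
# NUM_MODEL_PARTITIONS = 1
# DROPOUT_RATE = 0.1
# BATCH_SIZE = 128
# moe_trainer.MoeTrainer.num_microbatches = 4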