# Run inference with a Mixture of Experts (MoE) model.
#
# A binding for MODEL must also be included alongside this file.
#
# Required to be set:
#
# - NUM_EXPERTS
# - NUM_MODEL_PARTITIONS  (1 if no model parallelism)
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - CHECKPOINT_PATH
# - INFER_OUTPUT_DIR
#
# Commonly overridden options (see also t5x/configs/runs/infer.gin):
#
# - DROPOUT_RATE
# - BATCH_SIZE

from __gin__ import dynamic_registration

# The script being launched is referred to below as `infer_script`.
import __main__ as infer_script

from t5x.contrib.moe import partitioning as moe_partitioning
from t5x import utils

# Build on the base inference run config; the bindings below add to it.
include 't5x/configs/runs/infer.gin'

# Both macros are required and have no default in this file.
NUM_EXPERTS = %gin.REQUIRED
NUM_MODEL_PARTITIONS = %gin.REQUIRED

# We use the MoE partitioner.
infer_script.infer.partitioner = @moe_partitioning.MoePjitPartitioner()

# The partitioner and its logical axis rules are driven by the same two macros.
moe_partitioning.MoePjitPartitioner.num_experts = %NUM_EXPERTS
moe_partitioning.MoePjitPartitioner.num_partitions = %NUM_MODEL_PARTITIONS
moe_partitioning.MoePjitPartitioner.logical_axis_rules = @moe_partitioning.standard_logical_axis_rules()

moe_partitioning.standard_logical_axis_rules.num_experts = %NUM_EXPERTS
moe_partitioning.standard_logical_axis_rules.num_partitions = %NUM_MODEL_PARTITIONS

# The dataset batch size follows the BATCH_SIZE macro.
utils.DatasetConfig.batch_size = %BATCH_SIZE