# Run inference with a Mixture of Experts model.
#
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - NUM_EXPERTS
# - NUM_MODEL_PARTITIONS (1 if no model parallelism)
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - CHECKPOINT_PATH
# - INFER_OUTPUT_DIR
#
# Commonly overridden options (see also t5x/configs/runs/infer.gin):
#
# - DROPOUT_RATE
# - BATCH_SIZE
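#
# Example invocation (a sketch only, not part of this config: file paths, the
# task name, and numeric values are illustrative placeholders, assuming the
# standard t5x infer entry point and that the first --gin_file binds MODEL):
#
#   python -m t5x.infer \
#     --gin_file=path/to/model.gin \
#     --gin_file=path/to/this/moe/infer.gin \
#     --gin.NUM_EXPERTS=8 \
#     --gin.NUM_MODEL_PARTITIONS=1 \
#     --gin.MIXTURE_OR_TASK_NAME=\"my_task\" \
#     --gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 128}" \
#     --gin.CHECKPOINT_PATH=\"/path/to/checkpoint\" \
#     --gin.INFER_OUTPUT_DIR=\"/path/to/output\"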

from __gin__ import dynamic_registration

import __main__ as infer_script
from t5x.contrib.moe import partitioning as moe_partitioning
from t5x import utils

include 't5x/configs/runs/infer.gin'

NUM_EXPERTS = %gin.REQUIRED
NUM_MODEL_PARTITIONS = %gin.REQUIRED

# We use the MoE partitioner.
infer_script.infer.partitioner = @moe_partitioning.MoePjitPartitioner()
moe_partitioning.MoePjitPartitioner:
  num_experts = %NUM_EXPERTS
  num_partitions = %NUM_MODEL_PARTITIONS
  logical_axis_rules = @moe_partitioning.standard_logical_axis_rules()
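
# Logical axis rules that account for expert and model parallelism.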
moe_partitioning.standard_logical_axis_rules:
  num_experts = %NUM_EXPERTS
  num_partitions = %NUM_MODEL_PARTITIONS

utils.DatasetConfig.batch_size = %BATCH_SIZE
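
# Commonly overridden at launch time (a sketch; the values shown are
# illustrative, not defaults set by this file), e.g.:
#
#   --gin.DROPOUT_RATE=0.0 --gin.BATCH_SIZE=32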