File size: 1,742 Bytes
b100e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Fine-tune a Mixture of Experts model.
#
# This file supports fine-tuning with data, expert and model parallelism. To
# use model parallelism, set NUM_MODEL_PARTITIONS > 1.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - NUM_EXPERTS
# - NUM_MODEL_PARTITIONS  (1 if no model parallelism)
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS  (total step count, including pre-training steps)
# - MODEL_DIR
# - INITIAL_CHECKPOINT_PATH
#
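# Commonly overridden options (see also t5x/configs/runs/finetune.gin):
#
# - DROPOUT_RATE
# - BATCH_SIZE
# - MoeTrainer.num_microbatches  (this file sets trainer_cls to MoeTrainer, so
#   override the MoeTrainer binding rather than Trainer.num_microbatches)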

from __gin__ import dynamic_registration

import __main__ as train_script

from t5x.contrib.moe import partitioning as moe_partitioning
from t5x.contrib.moe import trainer as moe_trainer
from t5x import utils

include 't5x/configs/runs/finetune.gin'

# MoE sizing macros. %gin.REQUIRED means these have no default and must be
# bound by the user (e.g. in an including config or via a command-line
# override). Both are consumed by the partitioner and trainer bindings below.
NUM_EXPERTS = %gin.REQUIRED
NUM_MODEL_PARTITIONS = %gin.REQUIRED

# Partitioning: replace the default partitioner with the MoE-aware pjit
# partitioner. The partitioner and the logical axis rules it is given read the
# same two macros (NUM_EXPERTS, NUM_MODEL_PARTITIONS).
train_script.train.partitioner = @moe_partitioning.MoePjitPartitioner()
moe_partitioning.MoePjitPartitioner.num_experts = %NUM_EXPERTS
moe_partitioning.MoePjitPartitioner.num_partitions = %NUM_MODEL_PARTITIONS
moe_partitioning.MoePjitPartitioner.logical_axis_rules = @moe_partitioning.standard_logical_axis_rules()
moe_partitioning.standard_logical_axis_rules.num_experts = %NUM_EXPERTS
moe_partitioning.standard_logical_axis_rules.num_partitions = %NUM_MODEL_PARTITIONS

# Training loop: replace the default trainer class with the MoE trainer, which
# is additionally told the number of experts. Because trainer_cls is
# MoeTrainer, microbatching is controlled by the num_microbatches binding here.
train_script.train.trainer_cls = @moe_trainer.MoeTrainer
moe_trainer.MoeTrainer.num_experts = %NUM_EXPERTS
moe_trainer.MoeTrainer.num_microbatches = None
moe_trainer.MoeTrainer.learning_rate_fn = @utils.create_learning_rate_scheduler()

# Learning-rate schedule passed to the trainer above.
# NOTE(review): with factors = 'constant', warmup_steps is presumably not used
# by the scheduler -- confirm against utils.create_learning_rate_scheduler.
utils.create_learning_rate_scheduler.factors = 'constant'
utils.create_learning_rate_scheduler.base_learning_rate = 0.001
utils.create_learning_rate_scheduler.warmup_steps = 1000

# Checkpoint period of 2000 (presumably training steps), chosen to save
# slightly more often than the fine-tuning defaults.
utils.SaveCheckpointConfig:
  period = 2000