# Defaults for pretraining with train.py.
#
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS
# - MODEL_DIR (automatically set when using xm_launch)
#
# Commonly overridden options:
#
# - train/DatasetConfig.batch_size
# - train_eval/DatasetConfig.batch_size
# - PjitPartitioner.num_partitions
# - Trainer.num_microbatches
# - DROPOUT_RATE
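#
# Example launch-time gin file (a minimal sketch; the includes, mixture name,
# feature lengths, and step count below are illustrative only):
#
#   include 't5x/examples/t5/t5_1_1/small.gin'   # provides MODEL
#   include 't5x/configs/runs/pretrain.gin'      # this file
#
#   MIXTURE_OR_TASK_NAME = 'c4_v220_span_corruption'
#   TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 114}
#   TRAIN_STEPS = 100000
#   DROPOUT_RATE = 0.0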
from __gin__ import dynamic_registration

import __main__ as train_script
from t5x import gin_utils
from t5x import partitioning
from t5x import utils
from t5x import trainer

MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED
MODEL_DIR = %gin.REQUIRED
BATCH_SIZE = 128
USE_CACHED_TASKS = True

# DEPRECATED: Import this module in your gin file.
MIXTURE_OR_TASK_MODULE = None
SHUFFLE_TRAIN_EXAMPLES = True

# HW RNG is faster than SW, but has limited determinism.
# Most notably it is not deterministic across different
# submeshes.
USE_HARDWARE_RNG = False
# When None, the faster (but non-deterministic) hardware RNG is always used.
RANDOM_SEED = None
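
# Example: make a run fully reproducible by pinning the seed and disabling the
# hardware RNG (a sketch; the seed value is arbitrary):
#   USE_HARDWARE_RNG = False
#   RANDOM_SEED = 0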

# Can be overridden with `train.*`.
train_script.train:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = None
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS
  eval_steps = 20
  eval_period = 1000
  random_seed = %RANDOM_SEED
  use_hardware_rng = %USE_HARDWARE_RNG
  summarize_config_fn = @gin_utils.summarize_gin_config

partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None
  logical_axis_rules = @partitioning.standard_logical_axis_rules()
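
# Example: switch from pure data parallelism to 2-way model parallelism
# (a sketch; the right setting depends on model size and device topology):
#   partitioning.PjitPartitioner.num_partitions = 2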

train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = %SHUFFLE_TRAIN_EXAMPLES
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE

train_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE
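
# Example: decouple the train and train_eval batch sizes instead of using the
# shared BATCH_SIZE macro (values are illustrative):
#   train/utils.DatasetConfig.batch_size = 256
#   train_eval/utils.DatasetConfig.batch_size = 64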

utils.CheckpointConfig:
  restore = @utils.RestoreCheckpointConfig()
  save = @utils.SaveCheckpointConfig()
utils.RestoreCheckpointConfig:
  path = []  # initialize from scratch
utils.SaveCheckpointConfig:
  period = 1000
  dtype = 'float32'
  keep = None  # keep all checkpoints
  save_dataset = False  # don't checkpoint dataset state
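
# Example: continue pretraining from an existing checkpoint rather than from
# scratch (a sketch; the bucket path is illustrative only):
#   utils.RestoreCheckpointConfig:
#     path = 'gs://my-bucket/pretrained_model/checkpoint_100000'
#     mode = 'specific'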

trainer.Trainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()

utils.create_learning_rate_scheduler:
  factors = 'constant * rsqrt_decay'
  base_learning_rate = 1.0
  warmup_steps = 10000  # 10k to keep consistent with T5/MTF defaults.
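
# Example: a lower peak learning rate with a shorter linear warmup (a sketch;
# values are illustrative):
#   utils.create_learning_rate_scheduler:
#     factors = 'constant * linear_warmup * rsqrt_decay'
#     base_learning_rate = 0.01
#     warmup_steps = 1000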