# Defaults for finetuning with train.py.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS  # includes pretrain steps
# - MODEL_DIR  # automatically set when using xm_launch
# - INITIAL_CHECKPOINT_PATH
#
# When running locally, MODEL_DIR must be passed via the `--gin.MODEL_DIR` flag.
#
# `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt
# has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps.
#
# Commonly overridden options:
# - DROPOUT_RATE
# - BATCH_SIZE
# - PjitPartitioner.num_partitions
# - Trainer.num_microbatches
# - USE_CACHED_TASKS: Whether to load preprocessed SeqIO data from the cache
#    or preprocess it on the fly. Most common tasks are cached, so this
#    defaults to True.
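#
# A minimal example invocation (task name, model gin file, checkpoint path,
# and MODEL_DIR are illustrative; flag syntax assumes the standard t5x
# `train.py` launcher):
#
#   python -m t5x.train \
#     --gin_file=t5x/examples/t5/t5_1_1/small.gin \
#     --gin_file=t5x/configs/runs/finetune.gin \
#     --gin.MODEL_DIR="'/tmp/finetuned_model'" \
#     --gin.MIXTURE_OR_TASK_NAME="'wmt_t2t_ende_v003'" \
#     --gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 256}" \
#     --gin.TRAIN_STEPS=1_020_000 \
#     --gin.INITIAL_CHECKPOINT_PATH="'gs://t5-data/pretrained_models/t5.1.1.small/model.ckpt-1000000'"
#
# Here the pre-trained checkpoint has 1M steps, so TRAIN_STEPS = 1_020_000
# performs 20k fine-tuning steps.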

from __gin__ import dynamic_registration

import __main__ as train_script
import seqio
from t5x import gin_utils
from t5x import partitioning
from t5x import utils
from t5x import trainer

# Must be overridden
MODEL_DIR = %gin.REQUIRED
MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED
INITIAL_CHECKPOINT_PATH = %gin.REQUIRED

# Commonly overridden
DROPOUT_RATE = 0.1
USE_CACHED_TASKS = True
BATCH_SIZE = 128

# Sometimes overridden
EVAL_STEPS = 20

# Convenience overrides.
EVALUATOR_USE_MEMORY_CACHE = True
EVALUATOR_NUM_EXAMPLES = None  # Use all examples in the infer_eval dataset.
JSON_WRITE_N_RESULTS = None  # Write all inferences.
# HW RNG is faster than SW, but has limited determinism.
# Most notably it is not deterministic across different
# submeshes.
USE_HARDWARE_RNG = False
# None always uses the faster hardware RNG.
RANDOM_SEED = None
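# For a fully reproducible run, a typical override (seed value illustrative)
# pins the seed and keeps the software RNG:
#   USE_HARDWARE_RNG = False
#   RANDOM_SEED = 0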

# DEPRECATED: Import this module in your gin file instead.
MIXTURE_OR_TASK_MODULE = None

train_script.train:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS
  eval_steps = %EVAL_STEPS
  eval_period = 1000
  random_seed = %RANDOM_SEED
  use_hardware_rng = %USE_HARDWARE_RNG
  summarize_config_fn = @gin_utils.summarize_gin_config
  inference_evaluator_cls = @seqio.Evaluator

partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None
  logical_axis_rules = @partitioning.standard_logical_axis_rules()

seqio.Evaluator:
  logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
  num_examples = %EVALUATOR_NUM_EXAMPLES
  use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE

seqio.JSONLogger:
  write_n_results = %JSON_WRITE_N_RESULTS

train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = True
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE

train_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE

infer_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = None  # compute max
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False
  module = %MIXTURE_OR_TASK_MODULE

utils.CheckpointConfig:
  restore = @utils.RestoreCheckpointConfig()
  save = @utils.SaveCheckpointConfig()
utils.RestoreCheckpointConfig:
  path = %INITIAL_CHECKPOINT_PATH
  mode = 'specific'
  dtype = 'float32'
utils.SaveCheckpointConfig:
  period = 5000
  dtype = 'float32'
  keep = None  # keep all checkpoints
  save_dataset = False  # don't checkpoint dataset state
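# To bound disk usage, `keep` may be overridden to retain only the N most
# recent checkpoints, e.g. (illustrative value):
#   utils.SaveCheckpointConfig.keep = 3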

trainer.Trainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()
utils.create_learning_rate_scheduler:
  factors = 'constant'
  base_learning_rate = 0.001
  warmup_steps = 1000
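
# The scheduler also accepts composite factor strings. As a sketch (factor
# names assume the flax-style scheduler bundled with t5x; verify against your
# version), an inverse-sqrt decay schedule could be configured as:
#   utils.create_learning_rate_scheduler:
#     factors = 'constant * rsqrt_decay'
#     base_learning_rate = 1.0
#     warmup_steps = 10000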