# Defaults for training with train.py.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - TASK_PREFIX
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS
# - MODEL_DIR
#
# Commonly overridden options:
# - BATCH_SIZE
# - PjitPartitioner.num_partitions
# - Trainer.num_microbatches
# - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data or
#   preprocess on the fly.
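#
# A typical launch looks something like the following (a sketch only: the
# paths and values are placeholders, and the exact entry point depends on
# your setup; t5x accepts gin overrides via --gin.NAME=VALUE flags):
#
#   python -m t5x.train \
#     --gin_file=mt3/gin/model.gin \
#     --gin_file=mt3/gin/train.gin \
#     --gin.TASK_PREFIX="'mt3'" \
#     --gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 1024}" \
#     --gin.TRAIN_STEPS=100000 \
#     --gin.MODEL_DIR="'/tmp/mt3_model'" \
#     --gin.NUM_VELOCITY_BINS=1 \
#     --gin.PROGRAM_GRANULARITY="'flat'"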

from __gin__ import dynamic_registration

import __main__ as train_script
import seqio
from mt3 import mixing
from mt3 import preprocessors
from mt3 import tasks
from mt3 import vocabularies
from t5x import gin_utils
from t5x import partitioning
from t5x import trainer
from t5x import utils

# Must be overridden
TASK_PREFIX = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED
MODEL_DIR = %gin.REQUIRED

# Commonly overridden
TRAIN_TASK_SUFFIX = 'train'
EVAL_TASK_SUFFIX = 'eval'
USE_CACHED_TASKS = True
BATCH_SIZE = 256

# Sometimes overridden
EVAL_STEPS = 20

# Convenience overrides.
EVALUATOR_USE_MEMORY_CACHE = True
EVALUATOR_NUM_EXAMPLES = None  # Use all examples in the infer_eval dataset.
JSON_WRITE_N_RESULTS = 0  # Don't write any inferences.
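# (To write inferences, set JSON_WRITE_N_RESULTS to e.g. 10 for the first 10
# per eval, or to None to write all of them.)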

# Number of velocity bins: set to 1 (no velocity) or 127
NUM_VELOCITY_BINS = %gin.REQUIRED
VOCAB_CONFIG = @vocabularies.VocabularyConfig()
vocabularies.VocabularyConfig.num_velocity_bins = %NUM_VELOCITY_BINS

# Program granularity: set to 'flat', 'midi_class', or 'full'
PROGRAM_GRANULARITY = %gin.REQUIRED
preprocessors.map_midi_programs.granularity_type = %PROGRAM_GRANULARITY
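# ('flat' collapses all programs to one, 'midi_class' groups programs by MIDI
# instrument class, and 'full' keeps every program distinct; see
# preprocessors.map_midi_programs for the exact mapping.)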

# Maximum number of examples per mix, or None for no mixing
MAX_EXAMPLES_PER_MIX = None
mixing.mix_transcription_examples.max_examples_per_mix = %MAX_EXAMPLES_PER_MIX
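# For example, to mix up to 8 examples into each training example (an
# illustrative value, not a recommendation):
#   MAX_EXAMPLES_PER_MIX = 8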

train/tasks.construct_task_name:
  task_prefix = %TASK_PREFIX
  vocab_config = %VOCAB_CONFIG
  task_suffix = %TRAIN_TASK_SUFFIX

eval/tasks.construct_task_name:
  task_prefix = %TASK_PREFIX
  vocab_config = %VOCAB_CONFIG
  task_suffix = %EVAL_TASK_SUFFIX

train_script.train:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS
  eval_steps = %EVAL_STEPS
  eval_period = 5000
  random_seed = None  # use faster hardware RNG
  summarize_config_fn = @gin_utils.summarize_gin_config
  inference_evaluator_cls = @seqio.Evaluator

seqio.Evaluator:
  logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
  num_examples = %EVALUATOR_NUM_EXAMPLES
  use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE

seqio.JSONLogger:
  write_n_results = %JSON_WRITE_N_RESULTS

train/utils.DatasetConfig:
  mixture_or_task_name = @train/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = True
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = False

train_eval/utils.DatasetConfig:
  mixture_or_task_name = @train/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'eval'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False

infer_eval/utils.DatasetConfig:
  mixture_or_task_name = @eval/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'eval'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False

utils.CheckpointConfig:
  restore = None
  save = @utils.SaveCheckpointConfig()
utils.SaveCheckpointConfig:
  period = 5000
  dtype = 'float32'
  keep = None  # keep all checkpoints
  save_dataset = False  # don't checkpoint dataset state
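
# To initialize from an existing checkpoint, bind a RestoreCheckpointConfig;
# the path below is a placeholder:
#   utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig()
#   utils.RestoreCheckpointConfig:
#     path = '/path/to/checkpoint'
#     mode = 'specific'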

partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None
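
# For model parallelism, raise num_partitions (or set an explicit
# model_parallel_submesh), e.g. as a command-line override:
#   --gin.partitioning.PjitPartitioner.num_partitions=2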

trainer.Trainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()
utils.create_learning_rate_scheduler:
  factors = 'constant'
  base_learning_rate = 0.001
  warmup_steps = 1000
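
# Note: with factors = 'constant', warmup_steps has no effect. For warmup
# plus inverse-sqrt decay (a common t5x recipe), one could instead bind:
#   utils.create_learning_rate_scheduler:
#     factors = 'constant * linear_warmup * rsqrt_decay'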