# Defaults for training with train.py.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - TASK_PREFIX
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS
# - MODEL_DIR
#
# Commonly overridden options:
# - BATCH_SIZE
# - PjitPartitioner.num_partitions
# - Trainer.num_microbatches
# - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or
#     preprocess on the fly.
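#
# Example invocation (an illustrative sketch; the entry point, gin file paths,
# and binding values below are assumptions, not part of this config):
#
#   python -m t5x.train \
#     --gin_file=mt3/gin/model.gin \
#     --gin_file=mt3/gin/train.gin \
#     --gin.TASK_PREFIX="'mt3'" \
#     --gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 1024}" \
#     --gin.TRAIN_STEPS=100000 \
#     --gin.MODEL_DIR="'/tmp/mt3'" \
#     --gin.NUM_VELOCITY_BINS=1 \
#     --gin.PROGRAM_GRANULARITY="'full'"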

from __gin__ import dynamic_registration

import __main__ as train_script
import seqio
from mt3 import mixing
from mt3 import preprocessors
from mt3 import tasks
from mt3 import vocabularies
from t5x import gin_utils
from t5x import partitioning
from t5x import utils
from t5x import trainer

# Must be overridden
TASK_PREFIX = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED
MODEL_DIR = %gin.REQUIRED

# Commonly overridden
TRAIN_TASK_SUFFIX = 'train'
EVAL_TASK_SUFFIX = 'eval'
USE_CACHED_TASKS = True
BATCH_SIZE = 256

# Sometimes overridden
EVAL_STEPS = 20

# Convenience overrides.
EVALUATOR_USE_MEMORY_CACHE = True
EVALUATOR_NUM_EXAMPLES = None  # Use all examples in the infer_eval dataset.
JSON_WRITE_N_RESULTS = 0  # Don't write any inferences.

# Number of velocity bins: set to 1 (no velocity) or 127.
NUM_VELOCITY_BINS = %gin.REQUIRED
VOCAB_CONFIG = @vocabularies.VocabularyConfig()
vocabularies.VocabularyConfig.num_velocity_bins = %NUM_VELOCITY_BINS
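
# (For reference, the published MT3 configs reportedly use
# NUM_VELOCITY_BINS = 127 for piano-only transcription and 1 for the
# multi-instrument model; treat these values as illustrative.)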

# Program granularity: set to 'flat', 'midi_class', or 'full'.
PROGRAM_GRANULARITY = %gin.REQUIRED
preprocessors.map_midi_programs.granularity_type = %PROGRAM_GRANULARITY

# Maximum number of examples per mix, or None for no mixing.
MAX_EXAMPLES_PER_MIX = None
mixing.mix_transcription_examples.max_examples_per_mix = %MAX_EXAMPLES_PER_MIX
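
# (Illustrative: the multi-instrument MT3 training setup reportedly overrides
# this with MAX_EXAMPLES_PER_MIX = 8; leave as None to disable mixing.)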

train/tasks.construct_task_name:
  task_prefix = %TASK_PREFIX
  vocab_config = %VOCAB_CONFIG
  task_suffix = %TRAIN_TASK_SUFFIX

eval/tasks.construct_task_name:
  task_prefix = %TASK_PREFIX
  vocab_config = %VOCAB_CONFIG
  task_suffix = %EVAL_TASK_SUFFIX

train_script.train:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS
  eval_steps = %EVAL_STEPS
  eval_period = 5000
  random_seed = None  # use faster, hardware RNG
  summarize_config_fn = @gin_utils.summarize_gin_config
  inference_evaluator_cls = @seqio.Evaluator

seqio.Evaluator:
  logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
  num_examples = %EVALUATOR_NUM_EXAMPLES
  use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE

seqio.JSONLogger:
  write_n_results = %JSON_WRITE_N_RESULTS
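
# Three dataset streams feed the t5x train loop: 'train' drives the optimizer,
# 'train_eval' computes teacher-forced eval metrics on held-out data, and
# 'infer_eval' runs full inference through the seqio.Evaluator configured
# above (see t5x's train.py for the authoritative semantics).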
train/utils.DatasetConfig:
  mixture_or_task_name = @train/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = True
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = False

train_eval/utils.DatasetConfig:
  mixture_or_task_name = @train/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'eval'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False

infer_eval/utils.DatasetConfig:
  mixture_or_task_name = @eval/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'eval'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False

utils.CheckpointConfig:
  restore = None
  save = @utils.SaveCheckpointConfig()

utils.SaveCheckpointConfig:
  period = 5000
  dtype = 'float32'
  keep = None  # keep all checkpoints
  save_dataset = False  # don't checkpoint dataset state

partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None
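
# num_partitions = 1 means pure data parallelism; raising it (or setting
# model_parallel_submesh) shards model parameters across devices. These are
# the defaults, not tuned values.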

trainer.Trainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()

utils.create_learning_rate_scheduler:
  factors = 'constant'
  base_learning_rate = 0.001
  warmup_steps = 1000
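
# With factors = 'constant', the learning rate stays fixed at 1e-3 and
# warmup_steps has no effect; warmup applies only when the factors string
# includes 'linear_warmup' (e.g. 'constant * linear_warmup'). Behavior as in
# t5x's create_learning_rate_scheduler; verify against your t5x version.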