# Defaults for training with train.py.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - TASK_PREFIX
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS
# - MODEL_DIR
#
# Commonly overridden options:
# - BATCH_SIZE
# - PjitPartitioner.num_partitions
# - Trainer.num_microbatches
# - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess
#   on the fly.
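#
# Example invocation (a sketch: the file paths and values below are
# hypothetical, but t5x's train.py does accept repeated --gin_file flags and
# --gin.MACRO=VALUE overrides):
#
#   python -m t5x.train \
#     --gin_file="model.gin" \
#     --gin_file="train.gin" \
#     --gin.TASK_PREFIX="'mt3'" \
#     --gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 1024}" \
#     --gin.TRAIN_STEPS=100000 \
#     --gin.MODEL_DIR="'/tmp/mt3_model'" \
#     --gin.NUM_VELOCITY_BINS=1 \
#     --gin.PROGRAM_GRANULARITY="'full'"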

from __gin__ import dynamic_registration

import __main__ as train_script
import seqio
from mt3 import mixing
from mt3 import preprocessors
from mt3 import tasks
from mt3 import vocabularies
from t5x import gin_utils
from t5x import partitioning
from t5x import utils
from t5x import trainer

# Must be overridden
TASK_PREFIX = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED
MODEL_DIR = %gin.REQUIRED
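# For example (hypothetical values; valid task prefixes depend on the SeqIO
# tasks registered in mt3.tasks):
#
#   TASK_PREFIX = 'mt3'
#   TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 1024}
#   TRAIN_STEPS = 100000
#   MODEL_DIR = '/tmp/mt3_model'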

# Commonly overridden
TRAIN_TASK_SUFFIX = 'train'
EVAL_TASK_SUFFIX = 'eval'
USE_CACHED_TASKS = True
BATCH_SIZE = 256

# Sometimes overridden
EVAL_STEPS = 20

# Convenience overrides.
EVALUATOR_USE_MEMORY_CACHE = True
EVALUATOR_NUM_EXAMPLES = None  # Use all examples in the infer_eval dataset.
JSON_WRITE_N_RESULTS = 0  # Don't write any inferences.

# Number of velocity bins: set to 1 (no velocity) or 127
NUM_VELOCITY_BINS = %gin.REQUIRED
VOCAB_CONFIG = @vocabularies.VocabularyConfig()
vocabularies.VocabularyConfig.num_velocity_bins = %NUM_VELOCITY_BINS
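# (With a single bin, note events carry no velocity information; with 127 bins
# the decoded events should also include a quantized MIDI velocity per onset.)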

# Program granularity: set to 'flat', 'midi_class', or 'full'
PROGRAM_GRANULARITY = %gin.REQUIRED
preprocessors.map_midi_programs.granularity_type = %PROGRAM_GRANULARITY
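# ('flat' collapses all instruments onto a single program, 'midi_class' groups
# programs by their 8-program General MIDI class, e.g. all pianos, and 'full'
# keeps exact program numbers.)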

# Maximum number of examples per mix, or None for no mixing
MAX_EXAMPLES_PER_MIX = None
mixing.mix_transcription_examples.max_examples_per_mix = %MAX_EXAMPLES_PER_MIX
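# (Mixing sums the audio of several training examples and merges their targets
# to synthesize denser mixtures; e.g. MAX_EXAMPLES_PER_MIX = 8 would mix up to
# eight examples. The value 8 here is purely illustrative.)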

train/tasks.construct_task_name:
  task_prefix = %TASK_PREFIX
  vocab_config = %VOCAB_CONFIG
  task_suffix = %TRAIN_TASK_SUFFIX

eval/tasks.construct_task_name:
  task_prefix = %TASK_PREFIX
  vocab_config = %VOCAB_CONFIG
  task_suffix = %EVAL_TASK_SUFFIX

train_script.train:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS
  eval_steps = %EVAL_STEPS
  eval_period = 5000
  random_seed = None  # use faster, hardware RNG
  summarize_config_fn = @gin_utils.summarize_gin_config
  inference_evaluator_cls = @seqio.Evaluator
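
# train_eval computes teacher-forced loss and metrics on the training task,
# while infer_eval runs full inference through the seqio.Evaluator configured
# below (e.g. transcription metrics).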

seqio.Evaluator:
  logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
  num_examples = %EVALUATOR_NUM_EXAMPLES
  use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE

seqio.JSONLogger:
  write_n_results = %JSON_WRITE_N_RESULTS

train/utils.DatasetConfig:
  mixture_or_task_name = @train/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = True
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = False

train_eval/utils.DatasetConfig:
  mixture_or_task_name = @train/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'eval'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False

infer_eval/utils.DatasetConfig:
  mixture_or_task_name = @eval/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'eval'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False

utils.CheckpointConfig:
  restore = None
  save = @utils.SaveCheckpointConfig()

utils.SaveCheckpointConfig:
  period = 5000
  dtype = 'float32'
  keep = None  # keep all checkpoints
  save_dataset = False  # don't checkpoint dataset state

partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None

trainer.Trainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()

utils.create_learning_rate_scheduler:
  factors = 'constant'
  base_learning_rate = 0.001
  warmup_steps = 1000
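
# Note: with factors = 'constant' the learning rate is a fixed
# base_learning_rate; warmup_steps should only take effect if a factor such as
# 'linear_warmup' is added to the factors string.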