# Defaults for training with train.py.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - TASK_PREFIX
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS
# - MODEL_DIR
#
# Commonly overridden options:
# - BATCH_SIZE
# - PjitPartitioner.num_partitions
# - Trainer.num_microbatches
# - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess
#   on the fly.

from __gin__ import dynamic_registration

import __main__ as train_script
import seqio
from mt3 import mixing
from mt3 import preprocessors
from mt3 import tasks
from mt3 import vocabularies
from t5x import gin_utils
from t5x import partitioning
from t5x import trainer
from t5x import utils

# Must be overridden
TASK_PREFIX = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED
MODEL_DIR = %gin.REQUIRED

# Commonly overridden
TRAIN_TASK_SUFFIX = 'train'
EVAL_TASK_SUFFIX = 'eval'
USE_CACHED_TASKS = True
BATCH_SIZE = 256

# Sometimes overridden
EVAL_STEPS = 20

# Convenience overrides.
EVALUATOR_USE_MEMORY_CACHE = True
EVALUATOR_NUM_EXAMPLES = None  # Use all examples in the infer_eval dataset.
JSON_WRITE_N_RESULTS = 0  # Don't write any inferences.

# Number of velocity bins: set to 1 (no velocity) or 127.
NUM_VELOCITY_BINS = %gin.REQUIRED
VOCAB_CONFIG = @vocabularies.VocabularyConfig()
vocabularies.VocabularyConfig.num_velocity_bins = %NUM_VELOCITY_BINS

# Program granularity: set to 'flat', 'midi_class', or 'full'.
PROGRAM_GRANULARITY = %gin.REQUIRED
preprocessors.map_midi_programs.granularity_type = %PROGRAM_GRANULARITY

# Maximum number of examples per mix, or None for no mixing.
MAX_EXAMPLES_PER_MIX = None
mixing.mix_transcription_examples.max_examples_per_mix = %MAX_EXAMPLES_PER_MIX
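
# For example, to mix together up to 8 examples per training example (8 is a
# hypothetical value; this file defaults to no mixing), one could bind in an
# override gin file or on the command line:
#
#   MAX_EXAMPLES_PER_MIX = 8
#   # or: --gin.MAX_EXAMPLES_PER_MIX=8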
train/tasks.construct_task_name:
  task_prefix = %TASK_PREFIX
  vocab_config = %VOCAB_CONFIG
  task_suffix = %TRAIN_TASK_SUFFIX

eval/tasks.construct_task_name:
  task_prefix = %TASK_PREFIX
  vocab_config = %VOCAB_CONFIG
  task_suffix = %EVAL_TASK_SUFFIX

train_script.train:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS
  eval_steps = %EVAL_STEPS
  eval_period = 5000
  random_seed = None  # use faster, hardware RNG
  summarize_config_fn = @gin_utils.summarize_gin_config
  inference_evaluator_cls = @seqio.Evaluator

seqio.Evaluator:
  logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
  num_examples = %EVALUATOR_NUM_EXAMPLES
  use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE

seqio.JSONLogger:
  write_n_results = %JSON_WRITE_N_RESULTS

train/utils.DatasetConfig:
  mixture_or_task_name = @train/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = True
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = False

train_eval/utils.DatasetConfig:
  mixture_or_task_name = @train/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'eval'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False

infer_eval/utils.DatasetConfig:
  mixture_or_task_name = @eval/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'eval'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False

utils.CheckpointConfig:
  restore = None
  save = @utils.SaveCheckpointConfig()

utils.SaveCheckpointConfig:
  period = 5000
  dtype = 'float32'
  keep = None  # keep all checkpoints
  save_dataset = False  # don't checkpoint dataset state

partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None

trainer.Trainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()

utils.create_learning_rate_scheduler:
  factors = 'constant'
  base_learning_rate = 0.001
  warmup_steps = 1000
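
# Example invocation (a sketch only: the gin file paths, task prefix, feature
# lengths, and model dir below are hypothetical and depend on your checkout
# and tasks; the model gin file must provide the MODEL binding):
#
#   python -m t5x.train \
#     --gin_file=path/to/model.gin \
#     --gin_file=path/to/train.gin \
#     --gin.TASK_PREFIX="'my_task'" \
#     --gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 1024}" \
#     --gin.TRAIN_STEPS=100000 \
#     --gin.NUM_VELOCITY_BINS=1 \
#     --gin.PROGRAM_GRANULARITY="'flat'" \
#     --gin.MODEL_DIR="'/tmp/my_model_dir'"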