# Defaults for infer.py.
#
# You must also include a binding for MODEL (it is referenced below as %MODEL
# but is not defined in this file).
#
# Required to be set (each is declared as %gin.REQUIRED below):
#
# - TASK_PREFIX
# - TASK_FEATURE_LENGTHS
# - CHECKPOINT_PATH
# - INFER_OUTPUT_DIR
# - NUM_VELOCITY_BINS
# - PROGRAM_GRANULARITY
# - ONSETS_ONLY
# - USE_TIES
#
# Commonly overridden options:
#
# - infer.mode
# - infer.checkpoint_period
# - infer.shard_id
# - infer.num_shards
# - DatasetConfig.split
# - DatasetConfig.batch_size
# - DatasetConfig.use_cached
# - RestoreCheckpointConfig.is_tensorflow
# - RestoreCheckpointConfig.mode
# - PjitPartitioner.num_partitions

from __gin__ import dynamic_registration

import __main__ as infer_script
from mt3 import inference
from mt3 import preprocessors
from mt3 import tasks
from mt3 import vocabularies
from t5x import partitioning
from t5x import utils

# Must be overridden
TASK_PREFIX = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
CHECKPOINT_PATH = %gin.REQUIRED
INFER_OUTPUT_DIR = %gin.REQUIRED

# Number of velocity bins: set to 1 (no velocity) or 127
NUM_VELOCITY_BINS = %gin.REQUIRED
VOCAB_CONFIG = @vocabularies.VocabularyConfig()
vocabularies.VocabularyConfig.num_velocity_bins = %NUM_VELOCITY_BINS

# Program granularity: set to 'flat', 'midi_class', or 'full'
PROGRAM_GRANULARITY = %gin.REQUIRED
preprocessors.map_midi_programs.granularity_type = %PROGRAM_GRANULARITY

# The task name is built from the prefix, the vocabulary config, and this
# suffix; the result is used as DatasetConfig.mixture_or_task_name below.
TASK_SUFFIX = 'test'
tasks.construct_task_name:
  task_prefix = %TASK_PREFIX
  vocab_config = %VOCAB_CONFIG
  task_suffix = %TASK_SUFFIX

# Must be overridden. Both are forwarded to the inference writer below.
ONSETS_ONLY = %gin.REQUIRED
USE_TIES = %gin.REQUIRED
inference.write_inferences_to_file:
  vocab_config = %VOCAB_CONFIG
  onsets_only = %ONSETS_ONLY
  use_ties = %USE_TIES

infer_script.infer:
  mode = 'predict'
  model = %MODEL  # imported from separate gin file
  output_dir = %INFER_OUTPUT_DIR
  dataset_cfg = @utils.DatasetConfig()
  partitioner = @partitioning.PjitPartitioner()
  restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
  # HACK: pass an extremely large value here to make sure the entire dataset
  # fits in a single epoch. Otherwise, segments from a single example may end
  # up in different epochs after splitting.
  checkpoint_period = 1000000
  shard_id = 0
  num_shards = 1
  write_fn = @inference.write_inferences_to_file

utils.DatasetConfig:
  mixture_or_task_name = @tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  use_cached = True
  split = 'eval'
  batch_size = 32
  shuffle = False
  seed = 0
  pack = False

partitioning.PjitPartitioner.num_partitions = 1

utils.RestoreCheckpointConfig:
  path = %CHECKPOINT_PATH
  mode = 'specific'