# Provenance: commit b100e1c, "Add t5x and mt3 models" (juancopi81).
from __gin__ import dynamic_registration
from mt3 import network
from t5x import utils
include 'train.gin'
# Macro definitions for this configuration. They are not referenced in the
# visible part of this file; presumably they are consumed (as %NAME) by
# bindings in 'train.gin', which is included above — confirm there before
# changing any of them.

# NOTE(review): looks like a task-name prefix — verify how train.gin uses it.
TASK_PREFIX = 'mega_notes_ties'
# One length per feature name ('inputs' and 'targets').
# Units are presumably sequence positions/tokens — TODO confirm.
TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 1024}
TRAIN_STEPS = 150000
BATCH_SIZE = 256
# 0.0 presumably means label smoothing is disabled — verify against the consumer.
LABEL_SMOOTHING = 0.0
NUM_VELOCITY_BINS = 1
PROGRAM_GRANULARITY = 'full'
ONSETS_ONLY = False
USE_TIES = True
# NOTE(review): None presumably means "no limit" — confirm in the consumer.
MAX_EXAMPLES_PER_MIX = None

# Direct binding (not a macro): sets the `dropout_rate` parameter of
# network.T5Config to 0.1.
network.T5Config.dropout_rate = 0.1
# Checkpoint to restore from. %gin.REQUIRED means there is no default: the
# value must be bound elsewhere (another gin file or a command-line override),
# otherwise gin raises an error when the macro is used.
CHECKPOINT_PATH = %gin.REQUIRED

# Restore from a RestoreCheckpointConfig instance configured below.
utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig()

# Explicit dotted bindings are used instead of the indented
# `utils.RestoreCheckpointConfig:` block form. With the indentation lost,
# `path = ...` and `mode = ...` would no longer be parsed as parameters of
# RestoreCheckpointConfig. The dotted form does not depend on whitespace and
# matches the other bindings in this file.
utils.RestoreCheckpointConfig.path = %CHECKPOINT_PATH
# NOTE(review): 'specific' presumably makes t5x load exactly the checkpoint
# named by `path` — confirm against the t5x RestoreCheckpointConfig docs.
utils.RestoreCheckpointConfig.mode = 'specific'