# Provenance (repository page metadata, not part of the configuration):
#   author: juancopi81
#   commit: b100e1c "Add t5x and mt3 models"
#   size:   2.09 kB
# T5.1.1 Small model.
from __gin__ import dynamic_registration
from mt3 import models
from mt3 import network
from mt3 import spectrograms
from mt3 import vocabularies
import seqio
from t5x import adafactor
# ------------------- Loss HParam ----------------------------------------------
# Loss hyperparameters are declared as macros and referenced with `%NAME` in the
# model bindings below.
Z_LOSS = 0.0001
LABEL_SMOOTHING = 0.0
LOSS_NORMALIZING_FACTOR = None

# Bindings belonging to a `configurable:` header must be indented beneath it;
# an unindented `name = value` line is not scoped to the configurable.
models.ContinuousInputsEncoderDecoderModel:
  z_loss = %Z_LOSS
  label_smoothing = %LABEL_SMOOTHING
  loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
# ------------------- Output vocabulary ----------------------------------------
# VOCAB_CONFIG has no value in this file: it is marked as required, so it must
# be bound by whatever includes or overrides this configuration.
VOCAB_CONFIG = %gin.REQUIRED

# The codec is built from VOCAB_CONFIG and then handed to
# `vocabulary_from_codec`, whose result is published as OUTPUT_VOCABULARY.
vocabularies.build_codec:
  vocab_config = %VOCAB_CONFIG

vocabularies.vocabulary_from_codec:
  codec = @vocabularies.build_codec()

OUTPUT_VOCABULARY = @vocabularies.vocabulary_from_codec()
# ------------------- Optimizer ------------------------------------------------
# `learning_rate` is set by `Trainer.learning_rate_fn`, so it is intentionally
# not bound here.
OPTIMIZER = @adafactor.Adafactor()

# Indentation restored: these are parameters of `adafactor.Adafactor`, not
# free-standing macros.
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0
  logical_factor_rules = @adafactor.standard_logical_factor_rules()
# ------------------- Model ----------------------------------------------------
# SPECTROGRAM_CONFIG and MODEL are macros; MODEL is the configured
# encoder-decoder model instance.
SPECTROGRAM_CONFIG = @spectrograms.SpectrogramConfig()
MODEL = @models.ContinuousInputsEncoderDecoderModel()

# Indentation restored so that the following lines are parameters of
# `ContinuousInputsEncoderDecoderModel`. OUTPUT_VOCABULARY and OPTIMIZER are
# macros defined in this file's vocabulary and optimizer sections.
models.ContinuousInputsEncoderDecoderModel:
  module = @network.Transformer()
  input_vocabulary = @seqio.vocabularies.PassThroughVocabulary()
  output_vocabulary = %OUTPUT_VOCABULARY
  optimizer_def = %OPTIMIZER
  input_depth = @spectrograms.input_depth()

# NOTE(review): a pass-through input vocabulary of size 0 presumably reflects
# that inputs are continuous features rather than token ids -- confirm.
seqio.vocabularies.PassThroughVocabulary.size = 0

# The model's input depth is computed from the same spectrogram configuration.
spectrograms.input_depth.spectrogram_config = %SPECTROGRAM_CONFIG
# ------------------- Network specification ------------------------------------
network.Transformer.config = @network.T5Config()

# Indentation restored so that the following lines are fields of `T5Config`.
network.T5Config:
  # Embedding table size is computed from the output vocabulary (bound below).
  vocab_size = @vocabularies.num_embeddings()
  dtype = 'float32'
  emb_dim = 512
  # NOTE(review): num_heads * head_dim = 384, which differs from emb_dim = 512.
  # This looks deliberate for the "T5.1.1 Small" shape -- confirm before editing.
  num_heads = 6
  num_encoder_layers = 8
  num_decoder_layers = 8
  head_dim = 64
  mlp_dim = 1024
  mlp_activations = ('gelu', 'linear')
  dropout_rate = 0.1
  logits_via_embedding = False

vocabularies.num_embeddings.vocabulary = %OUTPUT_VOCABULARY