# T5.1.1 Small model.
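# This file defines the MT3 continuous-inputs encoder-decoder model on a
# T5.1.1-Small Transformer. It is meant to be included by a higher-level
# config that binds VOCAB_CONFIG (marked gin.REQUIRED below).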
from __gin__ import dynamic_registration

from mt3 import models
from mt3 import network
from mt3 import spectrograms
from mt3 import vocabularies
import seqio
from t5x import adafactor

# ------------------- Loss HParam ----------------------------------------------
Z_LOSS = 0.0001
LABEL_SMOOTHING = 0.0
LOSS_NORMALIZING_FACTOR = None
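# Per T5X conventions: `z_loss` adds a small penalty on the log-partition
# function to keep logits from drifting, `LABEL_SMOOTHING = 0.0` disables
# label smoothing, and leaving `LOSS_NORMALIZING_FACTOR` as None means the
# token-weighted loss is not divided by an explicit normalizing constant.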
models.ContinuousInputsEncoderDecoderModel:
  z_loss = %Z_LOSS
  label_smoothing = %LABEL_SMOOTHING
  loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR

# Output vocabulary
VOCAB_CONFIG = %gin.REQUIRED
OUTPUT_VOCABULARY = @vocabularies.vocabulary_from_codec()
vocabularies.vocabulary_from_codec.codec = @vocabularies.build_codec()
vocabularies.build_codec.vocab_config = %VOCAB_CONFIG
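# `%gin.REQUIRED` means this file cannot run on its own: an including config
# must bind VOCAB_CONFIG first. A sketch of one possible binding, with
# illustrative values rather than anything prescribed by this file:
#   VOCAB_CONFIG = @vocabularies.VocabularyConfig()
#   vocabularies.VocabularyConfig.num_velocity_bins = 1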

# ------------------- Optimizer ------------------------------------------------
# `learning_rate` is set by `Trainer.learning_rate_fn`.
OPTIMIZER = @adafactor.Adafactor()
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0
  logical_factor_rules = @adafactor.standard_logical_factor_rules()
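# `decay_rate = 0.8` and `step_offset = 0` follow the usual T5 Adafactor
# setup; `standard_logical_factor_rules` maps named parameter axes to the
# row/column roles Adafactor uses for its factored second-moment estimates.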

# ------------------- Model ----------------------------------------------------
SPECTROGRAM_CONFIG = @spectrograms.SpectrogramConfig()
MODEL = @models.ContinuousInputsEncoderDecoderModel()
models.ContinuousInputsEncoderDecoderModel:
  module = @network.Transformer()
  input_vocabulary = @seqio.vocabularies.PassThroughVocabulary()
  output_vocabulary = %OUTPUT_VOCABULARY
  optimizer_def = %OPTIMIZER
  input_depth = @spectrograms.input_depth()
seqio.vocabularies.PassThroughVocabulary.size = 0
spectrograms.input_depth.spectrogram_config = %SPECTROGRAM_CONFIG
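# The encoder consumes continuous spectrogram frames rather than token ids,
# so the input side gets a size-0 PassThroughVocabulary, and `input_depth`
# supplies the per-frame feature dimension derived from SPECTROGRAM_CONFIG
# (the number of mel bins, in MT3's default spectrogram setup).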

# ------------------- Network specification ------------------------------------
network.Transformer.config = @network.T5Config()
network.T5Config:
  vocab_size = @vocabularies.num_embeddings()
  dtype = 'float32'
  emb_dim = 512
  num_heads = 6
  num_encoder_layers = 8
  num_decoder_layers = 8
  head_dim = 64
  mlp_dim = 1024
  mlp_activations = ('gelu', 'linear')
  dropout_rate = 0.1
  logits_via_embedding = False
vocabularies.num_embeddings.vocabulary = %OUTPUT_VOCABULARY
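# These dimensions match T5.1.1-Small: 512-d embeddings, 8 encoder and 8
# decoder layers, 6 heads of 64 dims each, and a gated-GELU MLP of width 1024
# (the ('gelu', 'linear') activation pair). `logits_via_embedding = False`
# gives the decoder a separate output projection instead of reusing the
# embedding table, and `vocab_size` is derived from the output vocabulary via
# `num_embeddings` so the softmax width matches the codec.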