File size: 1,930 Bytes
b100e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# T5.1.1 Base model.
from __gin__ import dynamic_registration

import seqio
from t5x import adafactor
from t5x import models
from t5x.examples.t5 import network

# ------------------- Loss HParam ----------------------------------------------
# Macros defined in this section are consumed further down via `%NAME`:
# Z_LOSS, LABEL_SMOOTHING and LOSS_NORMALIZING_FACTOR are bound onto
# `models.EncoderDecoderModel`; DROPOUT_RATE is bound onto `network.T5Config`.
Z_LOSS = 0.0001
LABEL_SMOOTHING = 0.0
# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF)
# the loss normalizing factor should be set to pretraining batch_size *
# target_token_length.
LOSS_NORMALIZING_FACTOR = None
# Dropout should be specified in the "run" files.
# `%gin.REQUIRED` marks the value as intentionally unset in this file: a config
# that includes this one must override DROPOUT_RATE before the model is built.
DROPOUT_RATE = %gin.REQUIRED

# Vocabulary (shared by encoder and decoder): the VOCABULARY macro is bound to
# both `input_vocabulary` and `output_vocabulary` of the model below.
# `@seqio.SentencePieceVocabulary()` (with parentheses) binds the constructed
# object rather than the class itself.
VOCABULARY = @seqio.SentencePieceVocabulary()
# SentencePiece model file, read from a `gs://` path.
# NOTE(review): the path name suggests 32,000 pieces plus 100 extra ids --
# confirm against the vocabulary file if `vocab_size` is ever changed.
seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model"

# ------------------- Optimizer ------------------------------------------------
# `learning_rate` is set by `Trainer.learning_rate_fn`.
# OPTIMIZER is an evaluated reference (`@name()`), i.e. the macro holds the
# result of calling `adafactor.Adafactor` with the bindings below; it is passed
# to the model as `optimizer_def`.
OPTIMIZER = @adafactor.Adafactor()
# The indented bindings set `adafactor.Adafactor` parameters.
# `logical_factor_rules` is bound to the return value of
# `adafactor.standard_logical_factor_rules()`, not to the function object.
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0
  logical_factor_rules = @adafactor.standard_logical_factor_rules()

# ------------------- Model ----------------------------------------------------
# MODEL is the constructed `models.EncoderDecoderModel`, assembled from the
# macros defined above in this file: the same VOCABULARY is used for input and
# output, OPTIMIZER supplies `optimizer_def`, and the loss hyperparameters come
# from the "Loss HParam" macros. Overriding one of those macros in an including
# config therefore changes the model without editing these bindings.
MODEL = @models.EncoderDecoderModel()
# `module` is the network whose `config` is bound in the section that follows.
models.EncoderDecoderModel:
  module = @network.Transformer()
  input_vocabulary = %VOCABULARY
  output_vocabulary = %VOCABULARY
  optimizer_def = %OPTIMIZER
  z_loss = %Z_LOSS
  label_smoothing = %LABEL_SMOOTHING
  loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR

# ------------------- Network specification ------------------------------------
# Architecture hyperparameters for the "Base" size. `network.Transformer`
# receives a constructed `network.T5Config` as its `config` argument.
# Relationships visible in the values below:
#   * num_heads * head_dim = 12 * 64 = 768, which equals emb_dim.
#   * encoder and decoder have the same depth (12 layers each).
#   * dropout_rate is taken from the DROPOUT_RATE macro, which this file leaves
#     as `%gin.REQUIRED`, so it must be supplied by an including config.
# NOTE(review): `mlp_activations = ('gelu', 'linear')` presumably selects a
# gated-GELU feed-forward, and `logits_via_embedding = False` presumably means
# the output projection is not tied to the token embedding -- confirm both
# against `network.T5Config`.
network.Transformer.config = @network.T5Config()
network.T5Config:
  vocab_size = 32128  # vocab size rounded to a multiple of 128 for TPU efficiency
  dtype = 'bfloat16'
  emb_dim = 768
  num_heads = 12
  num_encoder_layers = 12
  num_decoder_layers = 12
  head_dim = 64
  mlp_dim = 2048
  mlp_activations = ('gelu', 'linear')
  dropout_rate = %DROPOUT_RATE
  logits_via_embedding = False