# Decoder-only model (Base) with 134307072 parameters.
from __gin__ import dynamic_registration
import seqio
from t5x import adafactor
from t5x import decoding
from t5x import models
from t5x.examples.decoder_only import network
# ------------------- Loss HParam ----------------------------------------------
Z_LOSS = 0.0001
LABEL_SMOOTHING = 0.0
# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF),
# the loss normalizing factor should be set to pretraining batch_size *
# target_token_length.
LOSS_NORMALIZING_FACTOR = None
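# Illustrative arithmetic only (the numbers below are examples, not values
# prescribed by this config): a pretraining batch_size of 2048 with a
# target_token_length of 114 would give 2048 * 114 = 233472.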
# Dropout should be specified in the "run" files
DROPOUT_RATE = %gin.REQUIRED
# Vocabulary (shared by encoder and decoder)
VOCABULARY = @seqio.SentencePieceVocabulary()
seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model"
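# Per the filename, this is the standard T5 SentencePiece vocabulary of 32,000
# tokens plus 100 extra IDs (32,100 in total); vocab_size below rounds this up
# to 32,128.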
# ------------------- Optimizer ------------------------------------------------
# `learning_rate` is set by `Trainer.learning_rate_fn`.
OPTIMIZER = @adafactor.Adafactor()
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0
  logical_factor_rules = @adafactor.standard_logical_factor_rules()
# ------------------- Model ----------------------------------------------------
MODEL = @models.DecoderOnlyModel()
models.DecoderOnlyModel:
  module = @network.DecoderWrapper()
  vocabulary = %VOCABULARY
  optimizer_def = %OPTIMIZER
  decode_fn = @decoding.temperature_sample
  z_loss = %Z_LOSS
  label_smoothing = %LABEL_SMOOTHING
  loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
decoding.temperature_sample:
  temperature = 1.0
  topk = 40
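# With the settings above, each decoding step samples only from the 40
# highest-probability tokens, and temperature = 1.0 leaves the logits unscaled.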
# ------------------- Network specification ------------------------------------
network.DecoderWrapper.config = @network.TransformerConfig()
network.TransformerConfig:
  vocab_size = 32128  # vocab size rounded to a multiple of 128 for TPU efficiency
  dtype = 'bfloat16'
  emb_dim = 768
  num_heads = 12
  num_layers = 12
  head_dim = 64
  mlp_dim = 2048
  mlp_activations = ('gelu', 'linear')
  dropout_rate = %DROPOUT_RATE
  logits_via_embedding = True
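
# Illustrative usage (sketch): a training "run" gin file would typically include
# this config and bind the required values. The path and dropout value below are
# hypothetical examples:
#   include 'decoder_only_base.gin'  # hypothetical path to this file
#   DROPOUT_RATE = 0.1               # example value; set per task in the run file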