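# NOTE(review): "Spaces: Build error / Build error" was captured along with this
# file; it looks like web-page status text rather than configuration. Kept here
# as a comment so it cannot be parsed as config -- confirm and remove.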
# mT5 Base model.
#
# Gin configuration. With `dynamic_registration` imported from `__gin__`, the
# modules imported below can be referenced as configurables in this file.

from __gin__ import dynamic_registration

import seqio
from t5x import adafactor
from t5x import models
from t5x.examples.t5 import network
# ------------------- Loss HParam ----------------------------------------------
# Macros referenced as %NAME by the model and network bindings in this file.
Z_LOSS = 0.0001
LABEL_SMOOTHING = 0.0
# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF)
# the loss normalizing factor should be set to pretraining batch_size *
# target_token_length.
LOSS_NORMALIZING_FACTOR = None
# Dropout should be specified in the "run" files. No value is given here:
# %gin.REQUIRED marks the macro as one that must be bound by another config.
DROPOUT_RATE = %gin.REQUIRED
# Vocabulary (shared by encoder and decoder)
# %VOCABULARY is bound to both `input_vocabulary` and `output_vocabulary` of
# `models.EncoderDecoderModel` in this file.
VOCABULARY = @seqio.SentencePieceVocabulary()
seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model"
# ------------------- Optimizer ------------------------------------------------
# `learning_rate` is set by `Trainer.learning_rate_fn`.
OPTIMIZER = @adafactor.Adafactor()
# Block syntax: the indented bindings below all apply to `adafactor.Adafactor`.
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0
  logical_factor_rules = @adafactor.standard_logical_factor_rules()
# ------------------- Model ----------------------------------------------------
MODEL = @models.EncoderDecoderModel()
# Block syntax: the indented bindings below all apply to
# `models.EncoderDecoderModel`. Values prefixed with % are macros defined
# elsewhere in this file.
models.EncoderDecoderModel:
  module = @network.Transformer()
  input_vocabulary = %VOCABULARY
  output_vocabulary = %VOCABULARY
  optimizer_def = %OPTIMIZER
  z_loss = %Z_LOSS
  label_smoothing = %LABEL_SMOOTHING
  loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
# ------------------- Network specification ------------------------------------
network.Transformer.config = @network.T5Config()
# Block syntax: the indented bindings below all apply to `network.T5Config`.
network.T5Config:
  vocab_size = 250112  # vocab size rounded to a multiple of 128 for TPU efficiency
  dtype = 'bfloat16'
  emb_dim = 768
  num_heads = 12
  num_encoder_layers = 12
  num_decoder_layers = 12
  head_dim = 64
  mlp_dim = 2048
  mlp_activations = ('gelu', 'linear')
  dropout_rate = %DROPOUT_RATE  # required macro; no value is bound in this file
  logits_via_embedding = False