|
|
|
|
|
# Enable gin dynamic registration so Python callables (seqio, t5x, local
# modules) can be configured directly from this file.
from __gin__ import dynamic_registration



import seqio

from t5x import adafactor

from t5x import models

# Project-local module; importing it registers the seqio task/mixture
# referenced by MIXTURE_OR_TASK_NAME below.
import tasks
|
|
|
# Must be bound by an included architecture config (done by the LongT5
# architecture file included below).
ARCHITECTURE = %gin.REQUIRED



# LongT5 1.1 encoder-decoder architecture definition (provides ARCHITECTURE).
include 'flaxformer/t5x/configs/longt5/architectures/longt5_1_1_flaxformer.gin'



# Standard t5x pretraining run setup (trainer, checkpointing, eval loop);
# consumes MIXTURE_OR_TASK_NAME, TASK_FEATURE_LENGTHS, TRAIN_STEPS, MODEL etc.
include 't5x/configs/runs/pretrain.gin'
|
|
|
|
|
# Span-corruption pretraining stream over the Scandinavian NCC corpus
# (registered by the local `tasks` import above).
MIXTURE_OR_TASK_NAME = "ncc_scandinavian_span_corruption_stream"

# NOTE(review): inputs=4048 looks like it could be a typo for 4096, the
# conventional LongT5 input length — confirm against the intended setup
# before "fixing", since feature lengths affect packing and checkpoints.
TASK_FEATURE_LENGTHS = {"inputs": 4048, "targets": 910}



BATCH_SIZE=32

TRAIN_STEPS = 1_000_000

# 0.0 disables dropout during pretraining.
DROPOUT_RATE = 0.0
|
|
|
|
|
|
|
|
|
|
|
# Transformer dimensions consumed by the included LongT5 architecture file.
NUM_HEADS = 12

NUM_ENCODER_LAYERS = 12

NUM_DECODER_LAYERS = 12

# Per-head attention dimension.
HEAD_DIM = 64

# Token embedding / model width.
EMBED_DIM = 768

# Feed-forward hidden width.
MLP_DIM = 2048
|
|
|
|
|
# Auxiliary z-loss coefficient (keeps logits from drifting; passed to the
# model's loss below).
Z_LOSS = 0.0001

# 0.0 disables label smoothing.
LABEL_SMOOTHING = 0.0

# None defers loss normalization to the model's default behavior.
LOSS_NORMALIZING_FACTOR = None
|
|
|
|
|
# Shared SentencePiece vocabulary, used for both inputs and outputs below.
VOCABULARY = @seqio.SentencePieceVocabulary()

# Standard T5 vocabulary: 32000 pieces + 100 extra ids (per the model path).
seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model"

# Embedding table size; 32128 is the standard T5 value for this vocabulary
# (32000 + 100 extra ids, presumably padded for hardware efficiency).
NUM_EMBEDDINGS = 32128
|
|
|
|
|
|
|
# Adafactor optimizer, wired into the model via optimizer_def below.
OPTIMIZER = @adafactor.Adafactor()

adafactor.Adafactor:

  decay_rate = 0.8

  # step_offset = 0: learning-rate/decay schedule starts at step 0.
  step_offset = 0
|
|
|
|
|
# Top-level model object consumed by the pretrain run config.
MODEL = @models.EncoderDecoderModel()

models.EncoderDecoderModel:

  # Flax module built by the included LongT5 architecture file.
  module = %ARCHITECTURE

  # Same SentencePiece vocabulary on both sides.
  input_vocabulary = %VOCABULARY

  output_vocabulary = %VOCABULARY

  optimizer_def = %OPTIMIZER

  z_loss = %Z_LOSS

  label_smoothing = %LABEL_SMOOTHING

  loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
|
|