juancopi81's picture
Add t5x and mt3 models
b100e1c
raw
history blame
1.83 kB
# A gin file to make the Transformer models tiny for faster local testing.
#
# When testing locally with CPU, there are a few things that we need.
# - tiny model size
# - small enough batch size
# - small sequence length
# - deterministic dataset pipeline
#
# This gin file adds such configs. To use this gin file, add it on top of the
# existing full-scale gin files. The ordering of the gin file matters. So this
# should be added after all the other files are added to override the same
# configurables.
from __gin__ import dynamic_registration
from t5x import partitioning
from t5x import trainer
from t5x import utils
from t5x.examples.t5 import network
import __main__ as train_script
# Macros referenced elsewhere via %NAME.
TASK_FEATURE_LENGTHS = {"inputs": 8, "targets": 7}
LABEL_SMOOTHING = 0.0

# Fixed seeds: one for the train script (dropout), one for the train dataset.
train_script.train.random_seed = 42
train/utils.DatasetConfig.seed = 42
# Network specification overrides.
#
# The bindings below are the body of the `network.T5Config:` block and must
# stay indented under that header to be applied to T5Config.
network.Transformer.config = @network.T5Config()
network.T5Config:
  vocab_size = 32128  # vocab size rounded to a multiple of 128 for TPU efficiency
  dtype = 'bfloat16'
  emb_dim = 8
  num_heads = 4
  num_encoder_layers = 2
  num_decoder_layers = 2
  head_dim = 3
  mlp_dim = 16
  mlp_activations = ('gelu', 'linear')
  dropout_rate = 0.0
  logits_via_embedding = False
  scan_layers = True
  remat_policy = 'minimal'
# Training-loop overrides: only a handful of steps and small batches.
TRAIN_STEPS = 3

# Block bodies must be indented under their `configurable:` header.
train/utils.DatasetConfig:
  batch_size = 8
  shuffle = False

train_eval/utils.DatasetConfig.batch_size = 8

train_script.train:
  eval_period = 3
  eval_steps = 3

trainer.Trainer.num_microbatches = 0
# Partitioning and checkpoint overrides. Block bodies must be indented under
# their `configurable:` header.
partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None

utils.CheckpointConfig:
  restore = None

# Inference eval uses the same feature lengths as the TASK_FEATURE_LENGTHS macro.
infer_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS

# DISABLE INFERENCE EVAL
# train_script.train.infer_eval_dataset_cfg = None