Spaces:
Build error
Build error
# A gin file to make the Transformer models tiny for faster local testing.
#
# When testing locally with CPU, there are a few things that we need:
# - tiny model size
# - small enough batch size
# - small sequence length
# - deterministic dataset pipeline
#
# This gin file adds such configs. To use this gin file, add it on top of the
# existing full-scale gin files. The ordering of the gin files matters, so this
# file should be added after all the other files so that it overrides the same
# configurables.
# Enable gin dynamic registration so that the modules imported below can be
# referenced directly in bindings. This import stays first.
from __gin__ import dynamic_registration

from t5x import partitioning
from t5x import trainer
from t5x import utils
from t5x.examples.t5 import network

# Whatever script is launched as __main__ is addressed as `train_script`.
import __main__ as train_script
# Pin both seeds to fixed values for local test runs.
train_script.train.random_seed = 42  # dropout seed
train/utils.DatasetConfig.seed = 42  # dataset seed

# Macros: short feature lengths keep the test sequences small.
TASK_FEATURE_LENGTHS = {"inputs": 8, "targets": 16}
LABEL_SMOOTHING = 0.0
# Network specification overrides: shrink every model dimension so the
# Transformer is tiny.
network.Transformer.config = @network.T5Config()
network.T5Config:
  dtype = 'float32'
  emb_dim = 8
  num_heads = 4
  num_encoder_layers = 2
  num_decoder_layers = 2
  head_dim = 3
  mlp_dim = 16
  mlp_activations = ('gelu', 'linear')
  dropout_rate = 0.0
  logits_via_embedding = False
# Run only a handful of steps.
TRAIN_STEPS = 3

# Training dataset: small batches, no shuffling.
train/utils.DatasetConfig:
  batch_size = 8
  shuffle = False

train_eval/utils.DatasetConfig.batch_size = 8

# Evaluation settings for the train function.
train_script.train:
  eval_period = 3
  eval_steps = 3
# NOTE(review): 0 presumably disables microbatching -- confirm against the
# t5x Trainer.
trainer.Trainer.num_microbatches = 0

# Single partition, no model-parallel submesh.
partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None

# No checkpoint restore configuration.
utils.CheckpointConfig:
  restore = None

# Inference-time evaluation uses the tiny feature lengths macro.
infer_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS