File size: 1,828 Bytes
b100e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# A gin file to make the Transformer models tiny for faster local testing.
#
# When testing locally with CPU, there are a few things that we need.
# - tiny model size
# - small enough batch size
# - small sequence length
# - deterministic dataset pipeline
#
# This gin file adds such configs. To use this gin file, add it on top of the
# existing full-scale gin files. The ordering of the gin file matters. So this
# should be added after all the other files are added to override the same
# configurables.

from __gin__ import dynamic_registration

from t5x import partitioning
from t5x import trainer
from t5x import utils
from t5x.examples.t5 import network

import __main__ as train_script

# Pin the seeds to fixed values so local test runs use the same randomness
# every time.
train_script.train.random_seed = 42  # dropout seed
# The `train/` prefix is a gin scope: this binding only applies to
# utils.DatasetConfig when it is configured under the `train` scope.
train/utils.DatasetConfig.seed = 42  # dataset seed

# Gin macros. TASK_FEATURE_LENGTHS is referenced later in this file via
# %TASK_FEATURE_LENGTHS. LABEL_SMOOTHING is not referenced in this file;
# presumably it overrides a macro consumed by the full-scale gin files this
# file is layered on top of -- TODO confirm.
TASK_FEATURE_LENGTHS = {"inputs": 8, "targets": 7}
LABEL_SMOOTHING = 0.0

# Network specification overrides
# Replaces the Transformer's config with a T5Config whose dimensions are
# shrunk (2 encoder / 2 decoder layers, emb_dim 8, mlp_dim 16) and whose
# dropout is turned off.
# NOTE(review): num_heads * head_dim = 12, which differs from emb_dim = 8;
# presumably intentional for a test-only model -- confirm.
network.Transformer.config = @network.T5Config()
network.T5Config:
  vocab_size = 32128  # vocab size rounded to a multiple of 128 for TPU efficiency
  dtype = 'bfloat16'
  emb_dim = 8
  num_heads = 4
  num_encoder_layers = 2
  num_decoder_layers = 2
  head_dim = 3
  mlp_dim = 16
  mlp_activations = ('gelu', 'linear')
  dropout_rate = 0.0
  logits_via_embedding = False
  scan_layers = True
  remat_policy = 'minimal'

# Macro for the number of training steps. It is not referenced in this file;
# presumably it is consumed by the base gin files being overridden -- TODO
# confirm.
TRAIN_STEPS = 3

# Training dataset (scope `train`): small batch, and shuffling disabled.
train/utils.DatasetConfig:
  batch_size = 8
  shuffle = False

# Train-eval dataset (scope `train_eval`): same small batch size.
train_eval/utils.DatasetConfig.batch_size = 8

# Evaluation cadence for the train script. eval_period equals the
# TRAIN_STEPS value set above.
train_script.train:
  eval_period = 3
  eval_steps = 3

# NOTE(review): 0 presumably means "no microbatching" -- verify against how
# trainer.Trainer interprets num_microbatches.
trainer.Trainer.num_microbatches = 0
# Single partition with no model-parallel submesh.
partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None

# No restore config is provided; presumably this makes training start from
# scratch instead of loading a checkpoint -- TODO confirm.
utils.CheckpointConfig:
  restore = None

# Inference-eval dataset (scope `infer_eval`) uses the feature lengths from
# the TASK_FEATURE_LENGTHS macro defined earlier in this file.
infer_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS


# To disable inference eval, uncomment the binding below.
# train_script.train.infer_eval_dataset_cfg = None