Spaces:
Build error
Build error
File size: 1,437 Bytes
b100e1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
# Defaults for precompile mode in main.py.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS (NOTE(review): not declared or referenced in this file; confirm it is still required)
# - MODEL_DIR (automatically set when using xm_launch)
#
# Commonly overridden options:
#
# - USE_CACHED_TASKS
# - BATCH_SIZE
# - PjitPartitioner.num_partitions
from __gin__ import dynamic_registration
import __main__ as train_script
import seqio
from t5x import gin_utils
from t5x import partitioning
from t5x import utils
from t5x import trainer
# ---- Required macros --------------------------------------------------------
# These have no default; each must be bound by an including config or a
# command-line gin binding.
MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
MODEL_DIR = %gin.REQUIRED

# ---- Commonly overridden macros ---------------------------------------------
# Forwarded to the `use_cached` and `batch_size` arguments of the train
# dataset config, respectively.
USE_CACHED_TASKS = True
BATCH_SIZE = 128

# Forwarded to `precompile`'s `random_seed` argument.
# None always uses faster, hardware RNG
RANDOM_SEED = None
# Arguments for the `precompile` entry point of the launching script
# (imported above as `train_script`). Bindings inside a `selector:` block must
# be indented so they attach to the configurable rather than being read as
# stand-alone macro assignments.
train_script.precompile:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  # Dataset config resolved under the `train` scope (bindings given below
  # under `train/utils.DatasetConfig`).
  train_dataset_cfg = @train/utils.DatasetConfig()
  partitioner = @partitioning.PjitPartitioner()
  random_seed = %RANDOM_SEED
# Default partitioner settings. `num_partitions` is listed in the file header
# as a commonly overridden option. Bindings are indented so they belong to
# this `selector:` block.
partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None
  backend = "tpu"
  # Evaluated via `@...()`, so the rules are the return value of
  # `standard_logical_axis_rules` called with its own configured arguments.
  logical_axis_rules = @partitioning.standard_logical_axis_rules()
# Dataset configuration used when `utils.DatasetConfig` is constructed under
# the `train` scope (as referenced by `train_script.precompile`). Bindings are
# indented so they belong to this `selector:` block.
train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = True
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = True
|