Spaces:
Build error
Build error
# Defaults for precompile mode in main.py.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS
#   (NOTE(review): not referenced by any binding visible in this file;
#   confirm it is still required for precompile.)
# - MODEL_DIR (automatically set when using xm_launch)
#
# Commonly overridden options:
#
# - USE_CACHED_TASKS
# - BATCH_SIZE
# - PjitPartitioner.num_partitions
from __gin__ import dynamic_registration

import __main__ as train_script
import seqio
from t5x import gin_utils
from t5x import partitioning
from t5x import utils
from t5x import trainer

# Macros without a default: each is marked gin.REQUIRED, so a real value has
# to be supplied by another gin file or an override before it is used.
MODEL_DIR = %gin.REQUIRED
MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED

# Commonly overridden
USE_CACHED_TASKS = True
BATCH_SIZE = 128

# None always uses faster, hardware RNG
RANDOM_SEED = None
# Arguments for the `precompile` entry point of the launching script
# (imported above as `train_script`).
train_script.precompile:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  # `@name()` is an evaluated reference: gin calls the configurable and passes
  # the result. The dataset config is resolved inside the `train` scope.
  train_dataset_cfg = @train/utils.DatasetConfig()
  partitioner = @partitioning.PjitPartitioner()
  random_seed = %RANDOM_SEED
# Settings for the partitioner that is handed to `train_script.precompile`.
partitioning.PjitPartitioner:
  num_partitions = 1  # commonly overridden
  model_parallel_submesh = None
  backend = "tpu"
  logical_axis_rules = @partitioning.standard_logical_axis_rules()
# Training dataset settings. The `train/` prefix is a gin scope, so these
# bindings apply only where the config is requested as
# `@train/utils.DatasetConfig()`.
train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = True
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = True