# Defaults for precompile mode in main.py.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS
# - MODEL_DIR: # automatically set when using xm_launch
#
# Commonly overridden options:
#
# - USE_CACHED_TASKS
# - BATCH_SIZE
# - PjitPartitioner.num_partitions
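#
# Illustrative invocation (a sketch only, not taken from this repository; the
# entry-point path and every value below are assumptions, following the usual
# t5x-style --gin_file / --gin.NAME command-line flag convention):
#
#   python main.py \
#     --gin_file="precompile.gin" \
#     --gin.MODEL_DIR="'/tmp/precompile_model_dir'" \
#     --gin.MIXTURE_OR_TASK_NAME="'my_seqio_task'" \
#     --gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 1024}" \
#     --gin.TRAIN_STEPS=10000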
from __gin__ import dynamic_registration

import __main__ as train_script
import seqio
from t5x import gin_utils
from t5x import partitioning
from t5x import utils
from t5x import trainer

MODEL_DIR = %gin.REQUIRED
MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
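
# Example bindings (illustrative values only; substitute your own task and
# output directory):
#
#   MIXTURE_OR_TASK_NAME = 'my_seqio_task'
#   TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 1024}
#   MODEL_DIR = '/tmp/precompile_model_dir'
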
# Commonly overridden
USE_CACHED_TASKS = True
BATCH_SIZE = 128
# None always uses faster, hardware RNG
RANDOM_SEED = None

train_script.precompile:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  partitioner = @partitioning.PjitPartitioner()
  random_seed = %RANDOM_SEED

partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None
  backend = "tpu"
  logical_axis_rules = @partitioning.standard_logical_axis_rules()
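
# Example override (illustrative value, not set by this file): to shard the
# model across devices, raise the partition count listed in the header as a
# commonly overridden option, e.g.
#
#   partitioning.PjitPartitioner.num_partitions = 2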

train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = True
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = True
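
# Note: use_cached = True assumes the seqio task has already been cached
# offline. If no cache exists for the task, override the flag defined above
# (illustrative):
#
#   USE_CACHED_TASKS = False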