File size: 1,311 Bytes
4557487 08a2592 4f33d95 28d6fa5 4557487 4f33d95 4557487 dd620f6 81de315 dd620f6 e01a248 4557487 e01a248 dd620f6 08a2592 4557487 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from __gin__ import dynamic_registration
import tasks
import __main__ as train_script
from t5.data import mixtures
from t5x import models
from t5x import partitioning
from t5x import utils
include "t5x/examples/t5/mt5/base.gin"
include "t5x/configs/runs/finetune.gin"
# Run-level macros. Some are read by bindings in this file; the rest are
# presumably consumed by the included base.gin / finetune.gin (not visible
# here) -- confirm against those files.
# Name of the task or mixture to fine-tune on (presumably registered by the
# imported `tasks` module). gin.REQUIRED: must be supplied at launch, e.g.
# --gin.MIXTURE_OR_TASK_NAME="...".
MIXTURE_OR_TASK_NAME = %gin.REQUIRED
# Per-feature sequence lengths: 512 for inputs, 2 for targets. The very short
# target length presumably fits a single-token label plus EOS -- TODO confirm
# against the task definition. Also reused by the infer_eval DatasetConfig
# binding in this file.
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 2}
# Checkpoint to start fine-tuning from; must be supplied at launch.
# Example (public mT5-Base T5X checkpoint):
#   "gs://t5-data/pretrained_models/t5x/mt5_base/checkpoint_1000000"
INITIAL_CHECKPOINT_PATH = %gin.REQUIRED
# Total step count INCLUDING the steps already in the initial checkpoint, not
# just the fine-tuning steps; must be supplied at launch. Example:
#   1_010_000 = 1_000_000 pre-trained steps + 10_000 fine-tuning steps.
TRAIN_STEPS = %gin.REQUIRED
# Do not use pre-cached task data (presumably data is preprocessed on the fly).
USE_CACHED_TASKS = False
# Dropout rate used during fine-tuning.
DROPOUT_RATE = 0.1
# Fixed random seed.
RANDOM_SEED = 0
# Override the eval-inference dataset config so it uses the same fixed feature
# lengths as training (TASK_FEATURE_LENGTHS). Without this it keeps whatever
# the included finetune.gin sets (not visible here -- presumably `None`, i.e.
# lengths computed from the data; confirm).
# The binding must be indented under the configurable header for gin's block
# syntax to attach it.
infer_eval/utils.DatasetConfig:
  task_feature_lengths = %TASK_FEATURE_LENGTHS
# Save a checkpoint every 1000 training steps.
# The binding must be indented under the configurable header for gin's block
# syntax to attach it.
utils.SaveCheckpointConfig:
  period = 1000
# Pere: Only necessary if we load a T5 model pre-trained with Mesh TensorFlow (see below). If we start from a T5X model, it can stay unset.
# `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained
# using Mesh TensorFlow (e.g. the public T5 / mT5 / ByT5 models), this should be
# set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1:
# `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`.
# LOSS_NORMALIZING_FACTOR = 234496  # = 1024 * 229, i.e. the mT5 value
# Might have to be changed depending on the architecture.
# partitioning.PjitPartitioner.num_partitions = 1
|