# T5X fine-tuning configuration: an encoder-decoder Transformer trained on the
# MIXTURE_OR_TASK_NAME ('translate') task, initialized from the pretrained
# checkpoint at INITIAL_CHECKPOINT_PATH and written to MODEL_DIR.
from __gin__ import dynamic_registration

import __main__ as train_script
import seqio
from t5.data import mixtures
from t5x import adafactor
from t5x.examples.t5 import network
from t5x import gin_utils
from t5x import models
from t5x import partitioning
from t5x import trainer
from t5x import utils
# NOTE(review): local module; presumably registers the 'translate' task/mixture
# named by MIXTURE_OR_TASK_NAME with seqio — confirm.
import tasks

# Macros:
# ==============================================================================
BATCH_SIZE = 32
DROPOUT_RATE = 0.1
EVAL_PERIOD = 1000
EVAL_STEPS = 20
EVALUATOR_NUM_EXAMPLES = None
EVALUATOR_USE_MEMORY_CACHE = True
INITIAL_CHECKPOINT_PATH = \
    'gs://north-t5x/pretrained_models/large/scandinavian3k_t5x_large/checkpoint_3000000'
JSON_WRITE_N_RESULTS = None
LABEL_SMOOTHING = 0.0
LOSS_NORMALIZING_FACTOR = None
MIXTURE_OR_TASK_MODULE = None
MIXTURE_OR_TASK_NAME = 'translate'
MODEL = @models.EncoderDecoderModel()
MODEL_DIR = 'gs://nb-t5x-us-central2/finetuned/scandi3_3stammer_v2_large'
OPTIMIZER = @adafactor.Adafactor()
RANDOM_SEED = 0
TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 512}
# NOTE(review): assumes train.total_steps is an absolute step count. If so,
# starting from checkpoint_3000000 this yields 100,000 fine-tuning steps.
# Confirm before changing either value.
TRAIN_STEPS = 3100000
USE_CACHED_TASKS = False
USE_HARDWARE_RNG = False
VOCABULARY = @seqio.SentencePieceVocabulary()
Z_LOSS = 0.0001

# Parameters for adafactor.Adafactor:
# ==============================================================================
adafactor.Adafactor.decay_rate = 0.8
adafactor.Adafactor.logical_factor_rules = \
    @adafactor.standard_logical_factor_rules()
adafactor.Adafactor.step_offset = 0

# Parameters for utils.CheckpointConfig:
# ==============================================================================
utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig()
utils.CheckpointConfig.save = @utils.SaveCheckpointConfig()

# Parameters for utils.create_learning_rate_scheduler:
# ==============================================================================
utils.create_learning_rate_scheduler.base_learning_rate = 0.001
# NOTE(review): with factors = 'constant' the schedule appears to be a flat
# base_learning_rate. warmup_steps seems to be read only by warmup/decay
# factors (e.g. 'linear_warmup'), so it likely has no effect here. Confirm
# whether a warmup was intended.
utils.create_learning_rate_scheduler.factors = 'constant'
utils.create_learning_rate_scheduler.warmup_steps = 1000

# Parameters for infer_eval/utils.DatasetConfig:
# ==============================================================================
# Inference evaluation data: unpacked, unshuffled validation split with a
# fixed seed.
infer_eval/utils.DatasetConfig.batch_size = %BATCH_SIZE
infer_eval/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
infer_eval/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
infer_eval/utils.DatasetConfig.pack = False
infer_eval/utils.DatasetConfig.seed = 42
infer_eval/utils.DatasetConfig.shuffle = False
infer_eval/utils.DatasetConfig.split = 'validation'
infer_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
infer_eval/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS

# Parameters for train/utils.DatasetConfig:
# ==============================================================================
# Training data: packed and shuffled; seed None leaves seeding unfixed here.
train/utils.DatasetConfig.batch_size = %BATCH_SIZE
train/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
train/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
train/utils.DatasetConfig.pack = True
train/utils.DatasetConfig.seed = None
train/utils.DatasetConfig.shuffle = True
train/utils.DatasetConfig.split = 'train'
train/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
train/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS

# Parameters for train_eval/utils.DatasetConfig:
# ==============================================================================
# Loss-based evaluation data: packed like training, but drawn from the
# validation split, unshuffled, with a fixed seed.
train_eval/utils.DatasetConfig.batch_size = %BATCH_SIZE
train_eval/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
train_eval/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
train_eval/utils.DatasetConfig.pack = True
train_eval/utils.DatasetConfig.seed = 42
train_eval/utils.DatasetConfig.shuffle = False
train_eval/utils.DatasetConfig.split = 'validation'
train_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
train_eval/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS

# Parameters for models.EncoderDecoderModel:
# ==============================================================================
models.EncoderDecoderModel.input_vocabulary = %VOCABULARY
models.EncoderDecoderModel.label_smoothing = %LABEL_SMOOTHING
models.EncoderDecoderModel.loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
models.EncoderDecoderModel.module = @network.Transformer()
models.EncoderDecoderModel.optimizer_def = %OPTIMIZER
models.EncoderDecoderModel.output_vocabulary = %VOCABULARY
models.EncoderDecoderModel.z_loss = %Z_LOSS

# Parameters for seqio.Evaluator:
# ==============================================================================
seqio.Evaluator.logger_cls = \
    [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
seqio.Evaluator.num_examples = %EVALUATOR_NUM_EXAMPLES
seqio.Evaluator.use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE

# Parameters for seqio.JSONLogger:
# ==============================================================================
seqio.JSONLogger.write_n_results = %JSON_WRITE_N_RESULTS

# Parameters for partitioning.PjitPartitioner:
# ==============================================================================
partitioning.PjitPartitioner.logical_axis_rules = \
    @partitioning.standard_logical_axis_rules()
partitioning.PjitPartitioner.model_parallel_submesh = None
partitioning.PjitPartitioner.num_partitions = 1

# Parameters for utils.RestoreCheckpointConfig:
# ==============================================================================
utils.RestoreCheckpointConfig.dtype = 'float32'
utils.RestoreCheckpointConfig.mode = 'specific'
utils.RestoreCheckpointConfig.path = %INITIAL_CHECKPOINT_PATH

# Parameters for utils.SaveCheckpointConfig:
# ==============================================================================
utils.SaveCheckpointConfig.dtype = 'float32'
utils.SaveCheckpointConfig.keep = None
utils.SaveCheckpointConfig.period = 10000
utils.SaveCheckpointConfig.save_dataset = False

# Parameters for seqio.SentencePieceVocabulary:
# ==============================================================================
seqio.SentencePieceVocabulary.sentencepiece_model_file = \
    'gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model'

# Parameters for network.T5Config:
# ==============================================================================
network.T5Config.dropout_rate = %DROPOUT_RATE
network.T5Config.dtype = 'bfloat16'
network.T5Config.emb_dim = 1024
network.T5Config.head_dim = 64
network.T5Config.logits_via_embedding = False
network.T5Config.mlp_activations = ('gelu', 'linear')
network.T5Config.mlp_dim = 2816
network.T5Config.num_decoder_layers = 24
network.T5Config.num_encoder_layers = 24
network.T5Config.num_heads = 16
# NOTE(review): presumably the 250,000 + 100 extra-id vocabulary above, padded
# up to a multiple of 128 (250112 = 1954 * 128). Confirm before changing.
network.T5Config.vocab_size = 250112

# Parameters for train_script.train:
# ==============================================================================
train_script.train.checkpoint_cfg = @utils.CheckpointConfig()
train_script.train.eval_period = %EVAL_PERIOD
train_script.train.eval_steps = %EVAL_STEPS
train_script.train.infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
train_script.train.inference_evaluator_cls = @seqio.Evaluator
train_script.train.model = %MODEL
train_script.train.model_dir = %MODEL_DIR
train_script.train.partitioner = @partitioning.PjitPartitioner()
train_script.train.random_seed = %RANDOM_SEED
train_script.train.summarize_config_fn = @gin_utils.summarize_gin_config
train_script.train.total_steps = %TRAIN_STEPS
train_script.train.train_dataset_cfg = @train/utils.DatasetConfig()
train_script.train.train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
train_script.train.trainer_cls = @trainer.Trainer
train_script.train.use_hardware_rng = %USE_HARDWARE_RNG

# Parameters for trainer.Trainer:
# ==============================================================================
trainer.Trainer.learning_rate_fn = @utils.create_learning_rate_scheduler()
trainer.Trainer.num_microbatches = None

# Parameters for network.Transformer:
# ==============================================================================
network.Transformer.config = @network.T5Config()