aapot committed on
Commit
8fdc728
1 Parent(s): 0b9efda

Add UL2 code

.gitattributes CHANGED
@@ -30,3 +30,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+checkpoint*/** filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__/
config.gin ADDED
@@ -0,0 +1,145 @@
+from __gin__ import dynamic_registration
+import __main__ as train_script
+import seqio
+from t5x import adafactor
+from t5x.examples.t5 import network
+from t5x import gin_utils
+from t5x import models
+from t5x import partitioning
+from t5x import trainer
+from t5x import utils
+import tasks
+
+# Macros:
+# ==============================================================================
+BATCH_SIZE = 256
+DROPOUT_RATE = 0.0
+LABEL_SMOOTHING = 0.0
+LOSS_NORMALIZING_FACTOR = None
+MIXTURE_OR_TASK_MODULE = None
+MIXTURE_OR_TASK_NAME = 'pretrain_finnish_ul2'
+MODEL = @models.EncoderDecoderModel()
+MODEL_DIR = '/researchdisk/ul2-small-nl16-finnish'
+OPTIMIZER = @adafactor.Adafactor()
+RANDOM_SEED = None
+SHUFFLE_TRAIN_EXAMPLES = True
+TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 512}
+TRAIN_STEPS = 500000
+USE_CACHED_TASKS = False
+USE_HARDWARE_RNG = False
+VOCABULARY = @seqio.SentencePieceVocabulary()
+Z_LOSS = 0.0001
+
+# Parameters for adafactor.Adafactor:
+# ==============================================================================
+adafactor.Adafactor.decay_rate = 0.8
+adafactor.Adafactor.logical_factor_rules = \
+    @adafactor.standard_logical_factor_rules()
+adafactor.Adafactor.step_offset = 0
+
+# Parameters for utils.CheckpointConfig:
+# ==============================================================================
+utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig()
+utils.CheckpointConfig.save = @utils.SaveCheckpointConfig()
+
+# Parameters for utils.create_learning_rate_scheduler:
+# ==============================================================================
+utils.create_learning_rate_scheduler.base_learning_rate = 1.0
+utils.create_learning_rate_scheduler.factors = 'constant * rsqrt_decay'
+utils.create_learning_rate_scheduler.warmup_steps = 10000
+
+# Parameters for train/utils.DatasetConfig:
+# ==============================================================================
+train/utils.DatasetConfig.batch_size = %BATCH_SIZE
+train/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+train/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
+train/utils.DatasetConfig.pack = True
+train/utils.DatasetConfig.seed = None
+train/utils.DatasetConfig.shuffle = %SHUFFLE_TRAIN_EXAMPLES
+train/utils.DatasetConfig.split = 'train'
+train/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
+train/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS
+
+# Parameters for train_eval/utils.DatasetConfig:
+# ==============================================================================
+train_eval/utils.DatasetConfig.batch_size = %BATCH_SIZE
+train_eval/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+train_eval/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
+train_eval/utils.DatasetConfig.pack = True
+train_eval/utils.DatasetConfig.seed = 42
+train_eval/utils.DatasetConfig.shuffle = False
+train_eval/utils.DatasetConfig.split = 'validation'
+train_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
+train_eval/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS
+
+# Parameters for models.EncoderDecoderModel:
+# ==============================================================================
+models.EncoderDecoderModel.input_vocabulary = %VOCABULARY
+models.EncoderDecoderModel.label_smoothing = %LABEL_SMOOTHING
+models.EncoderDecoderModel.loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
+models.EncoderDecoderModel.module = @network.Transformer()
+models.EncoderDecoderModel.optimizer_def = %OPTIMIZER
+models.EncoderDecoderModel.output_vocabulary = %VOCABULARY
+models.EncoderDecoderModel.z_loss = %Z_LOSS
+
+# Parameters for partitioning.PjitPartitioner:
+# ==============================================================================
+partitioning.PjitPartitioner.logical_axis_rules = \
+    @partitioning.standard_logical_axis_rules()
+partitioning.PjitPartitioner.model_parallel_submesh = None
+partitioning.PjitPartitioner.num_partitions = 1
+
+# Parameters for utils.RestoreCheckpointConfig:
+# ==============================================================================
+utils.RestoreCheckpointConfig.path = []
+
+# Parameters for utils.SaveCheckpointConfig:
+# ==============================================================================
+utils.SaveCheckpointConfig.dtype = 'float32'
+utils.SaveCheckpointConfig.keep = 10
+utils.SaveCheckpointConfig.period = 10000
+utils.SaveCheckpointConfig.save_dataset = False
+
+# Parameters for seqio.SentencePieceVocabulary:
+# ==============================================================================
+seqio.SentencePieceVocabulary.sentencepiece_model_file = 'spiece.model'
+
+# Parameters for network.T5Config:
+# ==============================================================================
+network.T5Config.dropout_rate = %DROPOUT_RATE
+network.T5Config.dtype = 'bfloat16'
+network.T5Config.emb_dim = 512
+network.T5Config.head_dim = 64
+network.T5Config.logits_via_embedding = False
+network.T5Config.mlp_activations = ('gelu', 'linear')
+network.T5Config.mlp_dim = 2048
+network.T5Config.num_decoder_layers = 16
+network.T5Config.num_encoder_layers = 16
+network.T5Config.num_heads = 8
+network.T5Config.vocab_size = 32128
+
+# Parameters for train_script.train:
+# ==============================================================================
+train_script.train.checkpoint_cfg = @utils.CheckpointConfig()
+train_script.train.eval_period = 10000
+train_script.train.eval_steps = 20
+train_script.train.infer_eval_dataset_cfg = None
+train_script.train.model = %MODEL
+train_script.train.model_dir = %MODEL_DIR
+train_script.train.partitioner = @partitioning.PjitPartitioner()
+train_script.train.random_seed = %RANDOM_SEED
+train_script.train.summarize_config_fn = @gin_utils.summarize_gin_config
+train_script.train.total_steps = %TRAIN_STEPS
+train_script.train.train_dataset_cfg = @train/utils.DatasetConfig()
+train_script.train.train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
+train_script.train.trainer_cls = @trainer.Trainer
+train_script.train.use_hardware_rng = %USE_HARDWARE_RNG
+
+# Parameters for trainer.Trainer:
+# ==============================================================================
+trainer.Trainer.learning_rate_fn = @utils.create_learning_rate_scheduler()
+trainer.Trainer.num_microbatches = None
+
+# Parameters for network.Transformer:
+# ==============================================================================
+network.Transformer.config = @network.T5Config()
model-info.txt ADDED
The diff for this file is too large to render. See raw diff
small_nl16.gin ADDED
@@ -0,0 +1,23 @@
+# T5.1.1 Efficient small nl16 model.
+
+import seqio
+include 't5x/examples/t5/t5_1_1/base.gin'  # imports vocab, optimizer and model.
+
+# ------------------- Network specification overrides --------------------------
+network.Transformer.config = @network.T5Config()
+network.T5Config:
+  emb_dim = 512
+  num_heads = 8
+  num_encoder_layers = 16
+  num_decoder_layers = 16
+  head_dim = 64
+  mlp_dim = 2048
+
+# ------------------- Model specification overrides --------------------------
+VOCABULARY = @seqio.SentencePieceVocabulary()
+seqio.SentencePieceVocabulary.sentencepiece_model_file = "spiece.model"
+
+MODEL = @models.EncoderDecoderModel()
+models.EncoderDecoderModel:
+  input_vocabulary = %VOCABULARY
+  output_vocabulary = %VOCABULARY
small_nl16_pretrain.gin ADDED
@@ -0,0 +1,25 @@
+# Register necessary SeqIO Tasks/Mixtures.
+from __gin__ import dynamic_registration
+from t5x import utils
+from t5x import partitioning
+import tasks
+import __main__ as train_script
+
+include 'small_nl16.gin'
+include 't5x/configs/runs/pretrain.gin'
+
+
+# ------------------- Training specification overrides --------------------------
+train_script.train:
+  eval_period = 10000
+
+utils.SaveCheckpointConfig:
+  period = 10000
+  keep = 10
+
+MIXTURE_OR_TASK_NAME = "pretrain_finnish_ul2"
+USE_CACHED_TASKS = False
+TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
+TRAIN_STEPS = 500000
+DROPOUT_RATE = 0.0
+BATCH_SIZE = 256
spiece.model ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b39375bd095c6d78f90949701049adc8f2a8fc1350c8c569c68ba079c80bda3f
+size 821929
spiece.vocab ADDED
The diff for this file is too large to render. See raw diff
start_train.sh ADDED
@@ -0,0 +1,12 @@
+# set train hyperparams
+unset LD_PRELOAD
+
+PROJECT_DIR="/researchdisk/ul2-small-nl16-finnish"
+T5X_DIR=${HOME}"/t5x"  # directory where the t5x is cloned.
+MODEL_DIR="/researchdisk/ul2-small-nl16-finnish"
+export PYTHONPATH=${PROJECT_DIR}
+
+python3 ${T5X_DIR}/t5x/train.py \
+  --gin_search_paths=${PROJECT_DIR} \
+  --gin_file="small_nl16_pretrain.gin" \
+  --gin.MODEL_DIR=\"${MODEL_DIR}\"
tasks.py ADDED
@@ -0,0 +1,105 @@
+import functools
+import seqio
+import tensorflow as tf
+import t5.data
+from datasets import load_dataset, load_from_disk
+from t5.data import postprocessors
+from t5.data import preprocessors
+from t5.evaluation import metrics
+from seqio import FunctionDataSource, utils
+
+from ul2_objective import ul2_objective
+
+# values from UL2 paper https://arxiv.org/pdf/2205.05131.pdf chapter 3.1.2 table 1
+R_DENOISER_SPAN_LENGTHS = [3.0, 8.0]
+X_DENOISER_SPAN_LENGTHS = [3.0, 8.0, 64.0, 64.0]
+R_DENOISER_CORRUPT_RATES = [0.15, 0.15]
+X_DENOISER_CORRUPT_RATES = [0.5, 0.5, 0.15, 0.5]
+
+R_DENOISER_TOKEN_PREFIX = '[NLU]'
+X_DENOISER_TOKEN_PREFIX = '[NLG]'
+S_DENOISER_TOKEN_PREFIX = '[S2S]'
+
+TaskRegistry = seqio.TaskRegistry
+
+vocabulary = seqio.SentencePieceVocabulary('spiece.model', extra_ids=0)
+
+DEFAULT_OUTPUT_FEATURES = {
+    "inputs": seqio.Feature(
+        vocabulary=vocabulary, add_eos=True,
+        required=False),
+    "targets": seqio.Feature(
+        vocabulary=vocabulary, add_eos=True)
+}
+
+
+def gen_dataset(split, shuffle=False, seed=None, column="text", dataset=None):
+    if shuffle:
+        if seed:
+            dataset = dataset.shuffle(seed=seed)
+        else:
+            dataset = dataset.shuffle()
+    while True:
+        for item in dataset[str(split)]:
+            if item[column] is not None:
+                yield item[column]
+
+
+def dataset_fn(split, shuffle_files, seed=None, dataset=None):
+    return tf.data.Dataset.from_generator(
+        functools.partial(gen_dataset, split, shuffle_files,
+                          seed, dataset=dataset),
+        output_signature=tf.TensorSpec(
+            shape=(), dtype=tf.string, name=dataset_name)
+    )
+
+
+@utils.map_over_dataset
+def target_to_key(x, key_map, target_key):
+    """Assign the value from the dataset to target_key in key_map."""
+    return {**key_map, target_key: x}
+
+
+dataset_name = "/researchdisk/lm_training_dataset_full"
+dataset_params = {"from_disk_path": dataset_name}
+
+if "from_disk_path" in dataset_params:
+    dataset = load_from_disk(dataset_params.get("from_disk_path"))
+else:
+    dataset = load_dataset(**dataset_params)
+
+dataset_shapes = {"train": dataset["train"].num_rows,
+                  "validation": dataset["validation"].num_rows}
+
+TaskRegistry.add(
+    "pretrain_finnish_ul2",
+    source=seqio.FunctionDataSource(
+        dataset_fn=functools.partial(dataset_fn, dataset=dataset),
+        splits=("train", "validation"),
+        caching_permitted=False,
+        num_input_examples=dataset_shapes,
+    ),
+    preprocessors=[
+        functools.partial(
+            target_to_key, key_map={
                "inputs": None,
+                "targets": None,
+            }, target_key="targets"),
+        seqio.preprocessors.tokenize,
+        functools.partial(
+            ul2_objective,
+            shard_ds=False,
+            use_prefix_lm_task=True,  # use S-denoising
+            rates=[0.4 / len(R_DENOISER_SPAN_LENGTHS)] * len(R_DENOISER_SPAN_LENGTHS) + [
+                0.4 / len(X_DENOISER_SPAN_LENGTHS)] * len(X_DENOISER_SPAN_LENGTHS) + [0.2],  # equal total 40% rate for both R- and X-denoisers + 20% for S-denoising (suggested in the paper, chapter 4.5)
+            mean_noise_span_lengths=R_DENOISER_SPAN_LENGTHS + X_DENOISER_SPAN_LENGTHS,
+            noise_densities=R_DENOISER_CORRUPT_RATES + X_DENOISER_CORRUPT_RATES,
+            optional_task_prefixes=[R_DENOISER_TOKEN_PREFIX] * len(R_DENOISER_SPAN_LENGTHS) + [
+                X_DENOISER_TOKEN_PREFIX] * len(X_DENOISER_SPAN_LENGTHS) + [S_DENOISER_TOKEN_PREFIX],
+            reserved_for_packing=1,  # make room for task prefix token
+        ),
+        seqio.preprocessors.append_eos_after_trim,
+    ],
+    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
+    metric_fns=[metrics.accuracy]
+)
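Note: the rates list passed to ul2_objective above resolves to a fixed mixing distribution over the denoisers. A minimal sketch of that arithmetic (plain Python, values copied from tasks.py; not part of the committed files):

    R_DENOISER_SPAN_LENGTHS = [3.0, 8.0]              # 2 R-denoiser settings
    X_DENOISER_SPAN_LENGTHS = [3.0, 8.0, 64.0, 64.0]  # 4 X-denoiser settings
    rates = ([0.4 / len(R_DENOISER_SPAN_LENGTHS)] * len(R_DENOISER_SPAN_LENGTHS)
             + [0.4 / len(X_DENOISER_SPAN_LENGTHS)] * len(X_DENOISER_SPAN_LENGTHS)
             + [0.2])
    print(rates)       # [0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.2]
    print(sum(rates))  # 1.0 up to float rounding: 40% R-, 40% X-, 20% S-denoising

The seven entries line up one-to-one with the seven task prefixes ([NLU] x2, [NLG] x4, [S2S]) that ul2_objective expects when use_prefix_lm_task=True.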
train/events.out.tfevents.1665907259.t1v-n-12f94ad0-w-0.185363.0.v2 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c87efc6802f4bdc5c6ba5fd3d8e13c0856ef1ce5854183f4d07efa3a3489342
+size 5950
train_sentencepiece.py ADDED
@@ -0,0 +1,7 @@
+import sentencepiece as spm
+
+spm.SentencePieceTrainer.train(input="/researchdisk/training_dataset_sentences/train.txt", model_prefix='spiece', vocab_size=32000, character_coverage=1.0,
+                               pad_id=0, unk_id=2, eos_id=1, bos_id=-1,
+                               user_defined_symbols=['[NLU]', '[NLG]', '[S2S]'],
+                               train_extremely_large_corpus=True,
+                               num_threads=96, input_sentence_size=50000000, shuffle_input_sentence=True)
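Note: an optional sanity check (a sketch, not part of the committed files) that the trained spiece.model actually contains the three UL2 prefix tokens registered via user_defined_symbols; it assumes the standard sentencepiece Python API and that spiece.model sits in the working directory:

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.Load('spiece.model')
    for tok in ['[NLU]', '[NLG]', '[S2S]']:
        # a user-defined symbol should encode as a single piece with its own id
        print(tok, sp.PieceToId(tok), sp.EncodeAsPieces(tok))
    print('pad/eos/unk ids:', sp.pad_id(), sp.eos_id(), sp.unk_id())  # expected 0, 1, 2 per the training call above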
training_eval/pretrain_finnish_ul2/events.out.tfevents.1665907259.t1v-n-12f94ad0-w-0.185363.1.v2 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cac17ca4282fe1cb0aafdd1122b0c6fda6d225a5fc62dbb60c311ee2f4c67d7b
+size 40
ul2_objective.py ADDED
@@ -0,0 +1,183 @@
+import functools
+import tensorflow as tf
+import seqio
+import t5.data
+from typing import Optional, Sequence
+
+
+# found this function and modified from https://github.com/GoogleCloudPlatform/t5x-on-vertex-ai/blob/main/tasks/custom_tasks.py#L78
+# UL2 paper appendix code missed this function
+def prepend_prompt(dataset: tf.data.Dataset,
+                   output_features: seqio.preprocessors.OutputFeaturesType,
+                   sequence_length: Optional[
+                       seqio.preprocessors.SequenceLengthType] = None,
+                   prompt_mode: str = "",
+                   key: str = "inputs",
+                   mode: str = "") -> tf.data.Dataset:
+    """Prepends a prompt at the beginning of an input sequence."""
+    del sequence_length
+    if prompt_mode and mode:
+        # output_features may not have inputs key
+        out_keys = list(output_features.keys())
+        prompt_tokens = output_features[out_keys[0]].vocabulary.encode_tf(prompt_mode)
+
+        def add_to_inputs(x):
+            x[key] = tf.concat([prompt_tokens, x[key]], axis=0)
+            return x
+
+        dataset = dataset.map(add_to_inputs)
+    return dataset
+
+
+# modified from t5.data.preprocessors because output_features may not have inputs key
+def split_tokens_to_inputs_length(dataset, sequence_length,
+                                  output_features, **kwargs):
+    max_tokens = sequence_length['inputs']
+    # output_features may not have inputs key
+    out_keys = list(output_features.keys())
+    if output_features[out_keys[0]].add_eos:
+        # Leave room to insert an EOS token.
+        max_tokens -= 1
+
+    return t5.data.preprocessors.split_tokens(dataset, max_tokens_per_segment=max_tokens, **kwargs)
+
+
+# modified from t5.data.preprocessors because output_features may not have inputs key
+def prefix_lm(dataset, sequence_length, output_features):
+    """Prefix language modeling objective used in Raffel et al. 2019."""
+    ds = dataset
+    ds = t5.data.preprocessors.select_random_chunk(ds, output_features=output_features,
+                                                   feature_key='targets', max_length=65536)
+    ds = split_tokens_to_inputs_length(ds, output_features=output_features,
+                                       sequence_length=sequence_length)
+    ds = t5.data.preprocessors.denoise(
+        ds,
+        output_features,
+        inputs_fn=t5.data.preprocessors.drop_nonnoise_tokens,
+        targets_fn=t5.data.preprocessors.drop_noise_tokens,
+        noise_density=0.5,
+        noise_mask_fn=t5.data.preprocessors.random_prefix_noise_mask,
+    )
+    return ds
+
+
+# copied from UL2 paper https://arxiv.org/pdf/2205.05131.pdf appendix chapter 9.2
+# note: modified to use the prefix_lm() from above instead of the default t5.data.preprocessors.prefix_lm() because output_features may not have inputs key
+def ul2_objective(dataset: tf.data.Dataset,
+                  sequence_length: seqio.preprocessors.SequenceLengthType,
+                  output_features: seqio.preprocessors.OutputFeaturesType,
+                  use_prefix_lm_task: bool = False,
+                  rates: Optional[Sequence[float]] = None,
+                  mean_noise_span_lengths: Sequence[float] = (3.0,),
+                  noise_densities: Sequence[float] = (0.15,),
+                  shard_ds: bool = True,
+                  optional_task_prefixes: Optional[Sequence[str]] = None,
+                  input_feature_key: str = "inputs",
+                  merge_examples_to_reduce_padding: bool = True,
+                  reserved_for_packing: Optional[int] = None,
+                  seed: int = 7) -> tf.data.Dataset:
+    """UL2-like pre-training objectives.
+
+    This preprocessor amounts to calling the 'span_corruption' function several
+    times with different values of 'noise_density' and 'mean_noise_span_length'.
+    We either shard or copy the dataset, then apply each function to each shard.
+    S-denoising (prefixLM) is added via use_prefix_lm_task.
+
+    Args:
+      dataset: a tf.data.Dataset with dictionaries containing the key 'input_feature_key'.
+      sequence_length: dict mapping of feature key to int length for that feature.
+      output_features: mapping of keys to features.
+      use_prefix_lm_task: if True, include the PrefixLM task in the mix.
+      rates: optional list of rates per task. If None, tasks are sampled uniformly.
+      mean_noise_span_lengths: list of mean number of tokens per masked span per example.
+      noise_densities: list of what fraction of the tokens to mask.
+      shard_ds: if True, shard the dataset per objective.
+      optional_task_prefixes: optional list of strings to prepend for each corruption scheme. NOTE: if including the prefixLM task, its prefix must be the last one.
+      input_feature_key: which feature to use from the dataset as the input text tokens.
+      merge_examples_to_reduce_padding: if True, combines multiple input examples to reduce padding.
+      reserved_for_packing: if specified, reduces the desired inputs length by the specified amount to enable multiple examples to be packed together downstream.
+      seed: tf.int64 for controlling the random choice of spans.
+
+    Returns:
+      a dataset
+    """
+    if optional_task_prefixes:  # Ensure each task has a prefix.
+        num_tasks = len(noise_densities) + int(use_prefix_lm_task)
+        valid_number_of_prefixes = num_tasks == len(optional_task_prefixes)
+        if not valid_number_of_prefixes:
+            raise ValueError(
+                "Number of task prefixes must match number of tasks.")
+    inputs_length = sequence_length[input_feature_key]
+    input_lengths, targets_lengths = [], []
+    sequence_lengths = {x: y for x, y in sequence_length.items()}
+    if reserved_for_packing:
+        inputs_length -= reserved_for_packing
+        for x, y in sequence_length.items():
+            sequence_lengths[x] = y - reserved_for_packing
+    hyperparams = list(zip(mean_noise_span_lengths, noise_densities))
+    for mean_noise_span_length, noise_density in hyperparams:
+        input_length, targets_length = t5.data.preprocessors.random_spans_helper(
+            extra_tokens_per_span_inputs=1,
+            extra_tokens_per_span_targets=1,
+            inputs_length=inputs_length,
+            mean_noise_span_length=mean_noise_span_length,
+            noise_density=noise_density)
+        input_lengths.append(input_length)
+        targets_lengths.append(targets_length)
+
+        if sequence_length["targets"] < targets_length:
+            upper_bound = max(targets_lengths)
+            raise ValueError(
+                f'Expected max targets length for span corruption ({upper_bound}) is '
+                f'greater than configured targets length '
+                f"({sequence_length['targets']})")
+    ds = dataset
+    ds = t5.data.preprocessors.select_random_chunk(
+        ds,
+        output_features=output_features,
+        feature_key="targets",
+        max_length=65536)
+    if merge_examples_to_reduce_padding:
+        ds = t5.data.preprocessors.reduce_concat_tokens(
+            ds, feature_key="targets", batch_size=128)
+    num_shards = len(input_lengths) + int(use_prefix_lm_task)
+    if shard_ds:
+        ds_shards = [ds.shard(num_shards, i) for i in range(num_shards)]
+    else:
+        ds_shards = [ds for _ in range(num_shards)]
+    processed_ds = []
+    hyperparams = zip(input_lengths, hyperparams, range(num_shards))
+    for input_length, (noise_span_length, noise_density), i in hyperparams:
+        ds = ds_shards[i]
+        ds = t5.data.preprocessors.split_tokens(
+            ds,
+            feature_key="targets",
+            min_tokens_per_segment=None,
+            max_tokens_per_segment=input_length)
+        ds = t5.data.preprocessors.denoise(
+            ds,
+            output_features,
+            inputs_fn=t5.data.preprocessors.noise_span_to_unique_sentinel,
+            targets_fn=t5.data.preprocessors.nonnoise_span_to_unique_sentinel,
+            noise_density=noise_density,
+            noise_mask_fn=functools.partial(
+                t5.data.preprocessors.random_spans_noise_mask,
+                mean_noise_span_length=noise_span_length),
+            input_feature_key=input_feature_key)
+        if optional_task_prefixes:
+            ds = prepend_prompt(
+                ds,
+                output_features,
+                prompt_mode=optional_task_prefixes[i],
+                mode=optional_task_prefixes[i],
+                key=input_feature_key)
+        processed_ds.append(ds)
+    if use_prefix_lm_task:
+        ds = ds_shards[-1]
+        ds = prefix_lm(
+            ds, sequence_lengths, output_features)
+        if optional_task_prefixes:
+            ds = prepend_prompt(
+                ds,
+                output_features,
+                prompt_mode=optional_task_prefixes[-1],
+                mode=optional_task_prefixes[-1],
+                key=input_feature_key)
+        processed_ds.append(ds)
+    ds = tf.data.experimental.sample_from_datasets(processed_ds, rates, seed)
+    return ds
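Note: to see what the per-denoiser length computation inside ul2_objective produces, here is a small standalone sketch (not part of the committed files; assumes the t5 package is installed, and mirrors the first R-denoiser in tasks.py with inputs_length = 512 - 1 reserved for the prefix token):

    import t5.data

    # same helper call the loop in ul2_objective makes for span length 3.0 at 15% corruption
    tokens_length, targets_length = t5.data.preprocessors.random_spans_helper(
        extra_tokens_per_span_inputs=1,
        extra_tokens_per_span_targets=1,
        inputs_length=511,
        mean_noise_span_length=3.0,
        noise_density=0.15)
    print(tokens_length, targets_length)  # segment length to split to, and the resulting targets length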