pere committed on
Commit
0a8986a
1 Parent(s): cb76020
__pycache__/tasks.cpython-38.pyc CHANGED
Binary files a/__pycache__/tasks.cpython-38.pyc and b/__pycache__/tasks.cpython-38.pyc differ
 
longt5/__pycache__/preprocessors.cpython-38.pyc ADDED
Binary file (4.8 kB).
 
longt5/preprocessors.py ADDED
@@ -0,0 +1,202 @@
+ # Copyright 2022 The LongT5 Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Copyright 2022 Google LLC.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Preprocessors for long T5."""
+
+ from pegasus.data import parsers
+ import seqio
+ import t5.data
+ import tensorflow.compat.v2 as tf
+
+
+ def _string_join(lst):
+   # Join on space, but collapse consecutive spaces.
+   out = tf.strings.join(lst, separator=' ')
+   return tf.strings.regex_replace(out, r'\s+', ' ')
+
+
+ def _normalize_text(text):
+   """Lowercase and remove quotes from a TensorFlow string."""
+   text = tf.strings.lower(text)
+   text = tf.strings.regex_replace(text, "'(.*)'", r'\1')
+   return text
+
+
+ @seqio.map_over_dataset
+ def nq(x):
+   """Convert NQ TF examples to a text2text pair.
+
+   NQ produces examples with this form:
+     {'id_': <id>, 'title': <title>, 'context': <article>,
+      'question': <question>, 'answer': <answer>}
+   This function will return examples of the format:
+     {'inputs': 'question: <question> context: <article>',
+      'targets': '<answer>',
+      'id': <id>, 'question': <question>, 'context': <context>,
+      'answers': [<n answers>]}
+
+   Args:
+     x: an example to process.
+
+   Returns:
+     A preprocessed example with the format listed above.
+   """
+   inputs = _string_join(['question:', x['question'], 'context:', x['context']])
+
+   return {
+       'inputs': inputs,
+       'targets': x['answer'],
+       'id': x['id_'],
+       'context': x['context'],
+       'question': x['question'],
+       'answers': [x['answer']]
+   }
+
+
+ @seqio.map_over_dataset
+ def triviaqa(x, ignore_web=True, include_title=True):
+   """Convert TriviaQA TF examples to a text2text pair.
+
+   TriviaQA produces examples with this form:
+     {'entity_pages': <dict of wiki entities>,
+      'search_results': <dict of web search results>,
+      'answer': <dict of all answers>, 'question': <question>,
+      'question_id': <question_id>, 'question_source': <question_source>}
+
+   This function will return examples of the format:
+     {'inputs': 'question: <question> context: <article>',
+      'targets': '<answer>',
+      'id': <id>, 'question': <question>, 'context': <context>,
+      'answers': [<n answers>]}
+
+   Args:
+     x: an example to process.
+     ignore_web: whether to ignore the web context.
+     include_title: whether to include the title.
+
+   Returns:
+     A preprocessed example with the format listed above.
+   """
+
+   question = _normalize_text(x['question'])
+
+   wiki_context = [_normalize_text(x['entity_pages']['wiki_context'])]
+   if include_title:
+     # Prepend the title to each context.
+     wiki_context = [_normalize_text(x['entity_pages']['title'])] + wiki_context
+   wiki_context = tf.transpose(tf.stack(wiki_context))
+   wiki_context = tf.strings.reduce_join(wiki_context, separator=' ')
+   context = wiki_context
+
+   if not ignore_web:
+     web_context = [_normalize_text(x['search_results']['search_context'])]
+     if include_title:
+       # Prepend the title to each context.
+       web_context = [_normalize_text(x['search_results']['title'])
+                     ] + web_context
+     web_context = tf.transpose(tf.stack(web_context))
+     web_context = tf.strings.reduce_join(web_context, separator=' ')
+     context = _string_join([wiki_context, web_context])
+
+   inputs = _string_join(['question:', question, 'context:', context])
+   targets = _normalize_text(x['answer']['value'])
+
+   return {
+       'inputs': inputs,
+       'targets': targets,
+       'id': x['question_id'],
+       'context': context,
+       'question': question,
+       'answers': x['answer']['aliases']
+   }
+
+
+ # Preprocessor for PEGASUS-style pretraining.
+ # Sentences/words are masked/replaced with different strategies. Details at
+ # https://arxiv.org/abs/1912.08777
+ pegasus_parser, _ = parsers.string_features_for_pretraining_parser(
+     vocab_filename='gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model',
+     encoder_type='sentencepiece_noshift',  # Matches tokenizer used by T5.
+     max_input_len=4096,
+     max_target_len=910,
+     max_total_words=0,
+     parser_strategy='dynamic_rouge',
+     parser_masked_sentence_ratio=0.2,
+     parser_masked_words_ratio=0,
+     parser_mask_word_option_prob=[0.8, 0.1, 0.1],
+     parser_mask_sentence_option_prob=[.9, 0, .1, 0],
+     parser_rouge_ngrams_size=1,
+     parser_rouge_metric_type='F',
+     parser_rouge_compute_option='standard',
+     # The stopwords file used is here: https://gist.github.com/sebleier/554280
+     parser_rouge_stopwords_filename='',
+     shift_special_token_id=t5.data.DEFAULT_EXTRA_IDS - 2,  # The 2 accounts for eos and pad.
+     mode='',
+     parser_rouge_noise_ratio=.2,
+     parser_dynamic_mask_min_ratio=.33,
+     input_feature='inputs',
+     pretrain_target_filter_min=0)
+
+
+ @seqio.map_over_dataset
+ def pegasus_parse(x):
+   """Parses an example with the Pegasus parser.
+
+   As input, the method receives:
+     {
+      'inputs': '<sent1> <sent2> .... <sentn>'
+      'targets': None
+     }
+   This function will return examples of the format:
+     {
+      'inputs': '<sent1> <mask> .... <sentn>'
+      'targets': '<sent2>'
+     }
+   though the returned example will have been tokenized with SPM and will
+   contain an EOS id at the end of both inputs and targets (as is also done
+   in T5).
+
+   Args:
+     x: an example to process.
+
+   Returns:
+     A preprocessed example, where some of the input is masked and copied to
+     the target. These values will have been tokenized with SPM.
+   """
+
+   # Add key 'supervised' as required by the Pegasus parser.
+   x['supervised'] = tf.constant(False, dtype=tf.bool)
+   # Parse the input. The Pegasus parser will return with some of the input
+   # masked and copied to the target (all having been tokenized).
+   parsed = pegasus_parser(x)
+   # Adjust outputs from the Pegasus parser to work with T5. This involves
+   # taking the elements at index 0 (to get the right shape needed) and
+   # casting from int64 to int32.
+   return {
+       'inputs': tf.cast(parsed['inputs'][0], tf.int32),
+       'targets': tf.cast(parsed['targets'][0], tf.int32)
+   }
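
The commit adds these preprocessors but does not show how they are attached to a seqio task (that registration lives in tasks.py). As a rough, hypothetical sketch, assuming a TFDS text source and the same cc_all vocabulary, pegasus_parse could be wired into a pretraining task like this; the task name and data source below are placeholders, not part of the commit:

import functools

import seqio

from longt5 import preprocessors as longt5_preprocessors

vocabulary = seqio.SentencePieceVocabulary(
    'gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model')

seqio.TaskRegistry.add(
    'example_pegasus_pretrain',  # hypothetical task name
    source=seqio.TfdsDataSource(tfds_name='c4/en:3.0.1'),  # hypothetical source
    preprocessors=[
        # Map the raw 'text' field onto the 'inputs' key the parser expects.
        functools.partial(
            seqio.preprocessors.rekey,
            key_map={'inputs': 'text', 'targets': None}),
        # Mask part of the input and copy it to 'targets'; output is token ids.
        longt5_preprocessors.pegasus_parse,
    ],
    output_features={
        'inputs': seqio.Feature(vocabulary=vocabulary, add_eos=False),
        'targets': seqio.Feature(vocabulary=vocabulary, add_eos=False),
    })

Because pegasus_parse already returns SPM token ids with EOS appended, no tokenize or append-EOS preprocessor is added after it in this sketch.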
longt5_1_1_base.gin ADDED
@@ -0,0 +1,61 @@
+ # LongT5 Base model. Config based on T5.1.1 Base model.
+ # Provides MODEL
+ from __gin__ import dynamic_registration
+
+ import seqio
+ from t5x import adafactor
+ from t5x import models
+ import tasks
+
+ ARCHITECTURE = %gin.REQUIRED
+
+ include 'flaxformer/t5x/configs/longt5/architectures/longt5_1_1_flaxformer.gin'
+
+ include 't5x/configs/runs/pretrain.gin'
+ #include 'pretrain_cont.gin'
+
+ MIXTURE_OR_TASK_NAME = "ncc_scandinavian_span_corruption_stream"
+ TASK_FEATURE_LENGTHS = {"inputs": 4048, "targets": 910}
+ # CORRECT IS 128!!
+ BATCH_SIZE = 32
+ TRAIN_STEPS = 1_000_000
+ DROPOUT_RATE = 0.0  # Changed from the default since T5-1.1 recommends this.
+ #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_NCC_plus_English_t5x_base/checkpoint_1500000"
+ #PjitPartitioner.num_partitions = 1
+
+
+ # Architecture overrides
+ NUM_HEADS = 12
+ NUM_ENCODER_LAYERS = 12
+ NUM_DECODER_LAYERS = 12
+ HEAD_DIM = 64
+ EMBED_DIM = 768
+ MLP_DIM = 2048
+
+ # Loss HParam defaults
+ Z_LOSS = 0.0001
+ LABEL_SMOOTHING = 0.0
+ LOSS_NORMALIZING_FACTOR = None
+
+ # Vocabulary (shared by encoder and decoder)
+ VOCABULARY = @seqio.SentencePieceVocabulary()
+ seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model"
+ NUM_EMBEDDINGS = 32128  # vocab size rounded to a multiple of 128 for TPU efficiency
+
+ # Optimizer
+ # `learning_rate` is set by `Trainer.learning_rate_fn`.
+ OPTIMIZER = @adafactor.Adafactor()
+ adafactor.Adafactor:
+   decay_rate = 0.8
+   step_offset = 0
+
+ # Model
+ MODEL = @models.EncoderDecoderModel()
+ models.EncoderDecoderModel:
+   module = %ARCHITECTURE  # provided by longt5_flaxformer
+   input_vocabulary = %VOCABULARY
+   output_vocabulary = %VOCABULARY
+   optimizer_def = %OPTIMIZER
+   z_loss = %Z_LOSS
+   label_smoothing = %LABEL_SMOOTHING
+   loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
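
A note on NUM_EMBEDDINGS: the cc_all.32000.100extra SentencePiece model has 32,000 pieces and T5 conventionally reserves 100 extra sentinel ids, giving 32,100 tokens, which rounds up to 32,128 at the next multiple of 128. A tiny illustrative helper (not part of the commit) for that padding rule:

def padded_vocab_size(vocab_size: int, multiple: int = 128) -> int:
    """Rounds a vocabulary size up to the next multiple.

    This is the TPU-friendly padding applied to embedding tables.
    """
    return ((vocab_size + multiple - 1) // multiple) * multiple

assert padded_vocab_size(32000 + 100) == 32128  # matches NUM_EMBEDDINGS above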
tasks.py CHANGED
@@ -1,5 +1,4 @@
 import functools
-
 import seqio
 import tensorflow as tf
 import t5.data
@@ -10,9 +9,7 @@ from t5.evaluation import metrics
 from seqio import FunctionDataSource, utils
 
 TaskRegistry = seqio.TaskRegistry
-
-vocabulary = seqio.SentencePieceVocabulary('gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model', extra_ids=0)
-byt5_vocabulary = t5.data.ByteVocabulary()
+vocabulary = seqio.SentencePieceVocabulary('gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model', extra_ids=0)
 
 DEFAULT_OUTPUT_FEATURES = {
     "inputs": seqio.Feature(
@@ -22,14 +19,6 @@ DEFAULT_OUTPUT_FEATURES = {
         vocabulary=vocabulary, add_eos=True)
 }
 
-BYT5_DEFAULT_OUTPUT_FEATURES = {
-    "inputs": seqio.Feature(
-        vocabulary=byt5_vocabulary, add_eos=True,
-        required=False),
-    "targets": seqio.Feature(
-        vocabulary=byt5_vocabulary, add_eos=True)
-}
-
 
 def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_params=None):
     dataset = load_dataset(**dataset_params)
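
With this change, tasks.py builds its output features with the same 32k cc_all vocabulary that longt5_1_1_base.gin configures for the model, rather than the much larger multilingual mc4.250000.100extra vocabulary. A quick, hypothetical check (not part of the commit) that the task vocabulary still fits the model's embedding table:

import seqio

vocab = seqio.SentencePieceVocabulary(
    'gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model', extra_ids=0)

# NUM_EMBEDDINGS in longt5_1_1_base.gin is 32128; the vocabulary must not exceed it.
assert vocab.vocab_size <= 32128

# Round-trip a sample string through the tokenizer.
ids = vocab.encode('Dette er en test.')
print(ids)
print(vocab.decode(ids))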
train_long_base.sh ADDED
@@ -0,0 +1,9 @@
+ PROJECT_DIR=${HOME}"/models/long-t5x"
+ T5X_DIR="../../t5x"  # directory where t5x is cloned.
+ MODEL_DIR="gs://nb-t5x-us-central2/long_test_t5x_base"
+ export PYTHONPATH=${PROJECT_DIR}
+
+ python3 ${T5X_DIR}/t5x/train.py \
+   --gin_search_paths=${PROJECT_DIR} \
+   --gin_file="longt5_1_1_base.gin" \
+   --gin.MODEL_DIR="'${MODEL_DIR}'" \