|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Preprocessors for T5 Tasks.""" |
|
|
|
|
|
import collections |
|
import functools |
|
import math |
|
import re |
|
from typing import Callable, Mapping, Optional, Sequence, Union |
|
import uuid |
|
|
|
from absl import logging |
|
import babel |
|
import gin |
|
import seqio |
|
import tensorflow.compat.v2 as tf |
|
|
|
|
|
|
|
|
|
AUTOTUNE = tf.data.experimental.AUTOTUNE |
|
|
|
FeatureType = Mapping[str, tf.Tensor] |
|
|
|
rekey = seqio.preprocessors.rekey |
|
tokenize = seqio.preprocessors.tokenize |
|
|
|
|
|
@seqio.map_over_dataset |
|
def translate(x, source_language, target_language): |
|
"""Convert a translation dataset to a text2text pair. |
|
|
|
For example, say the dataset returns examples of this format: |
|
{'de': 'Das ist gut.', 'en': 'That is good.'} |
|
If source_language = 'de', target_language = 'en', then the outputs will have |
|
the format: |
|
{'inputs': 'translate German to English: Das ist gut.', |
|
'targets': 'That is good.'} |
|
|
|
Args: |
|
x: an example to process. |
|
source_language: source language code (e.g. 'en') to translate from. |
|
target_language: target language code (e.g. 'de') to translate to. |
|
|
|
Returns: |
|
A preprocessed example with the format listed above. |
|
""" |
|
|
|
for language in (source_language, target_language): |
|
if language != language[:2]: |
|
logging.warning( |
|
'Extended language code %s not supported. Falling back on %s.', |
|
language, language[:2] |
|
) |
|
lang_id_to_string = { |
|
source_language: babel.Locale(source_language[:2]).english_name, |
|
target_language: babel.Locale(target_language[:2]).english_name, |
|
} |
|
src_str = 'translate {}'.format(lang_id_to_string[source_language]) |
|
tgt_str = ' to {}: '.format(lang_id_to_string[target_language]) |
|
return { |
|
'inputs': tf.strings.join([src_str, tgt_str, x[source_language]]), |
|
'targets': x[target_language], |
|
} |
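# Illustrative usage sketch (a minimal, hypothetical example; not executed here).
# `translate` is wrapped by @seqio.map_over_dataset, so it is applied to a dataset:
#
#   ds = tf.data.Dataset.from_tensor_slices(
#       {'de': ['Das ist gut.'], 'en': ['That is good.']})
#   ds = translate(ds, source_language='de', target_language='en')
#   # -> {'inputs': b'translate German to English: Das ist gut.',
#   #     'targets': b'That is good.'}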
|
|
|
|
|
@seqio.map_over_dataset |
|
def summarize(x, article_key, summary_key): |
|
"""Convert a summarization dataset to a text2text pair. |
|
|
|
For example, say the dataset returns examples of this format: |
|
{'article': <article>, 'highlights': <summary>} |
|
If article_key = 'article', summary_key = 'highlights', then the outputs will |
|
have the format: |
|
{'inputs': 'summarize: <article>', 'targets': <summary>}
|
|
|
Args: |
|
x: an example to process. |
|
article_key: the feature key for the article to summarize. |
|
summary_key: the feature key for the target summary. |
|
Returns: |
|
A preprocessed example with the format listed above. |
|
""" |
|
strs_to_join = ['summarize:', x[article_key]] |
|
return { |
|
'inputs': tf.strings.join(strs_to_join, separator=' '), |
|
'targets': x[summary_key], |
|
} |
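# Illustrative usage sketch (hypothetical toy data; not executed here):
#
#   ds = tf.data.Dataset.from_tensor_slices(
#       {'article': ['Some long article text.'],
#        'highlights': ['A short summary.']})
#   ds = summarize(ds, article_key='article', summary_key='highlights')
#   # -> {'inputs': b'summarize: Some long article text.',
#   #     'targets': b'A short summary.'}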
|
|
|
|
|
|
|
|
|
|
|
|
|
NON_SPACED_LANGUAGE_RANGES = ( |
|
'\u1000-\u104f', |
|
'\u4e00-\u9fff', |
|
'\u3400-\u4dbf', |
|
'\uf900-\ufaff', |
|
'\u2e80-\u2eff', |
|
'\u31c0-\u31ef', |
|
'\u3000-\u303f', |
|
'\u3040-\u309f', |
|
'\u30a0-\u30ff', |
|
'\ua980-\ua9df', |
|
'\u1780-\u17ff', |
|
'\u19e0-\u19ff', |
|
'\u0e80-\u0eff', |
|
'\u1980-\u19df', |
|
'\u1a20-\u1aaf', |
|
'\u0e00-\u0e7f', |
|
'\u0f00-\u0fff', |
|
) |
|
|
|
|
|
@seqio.map_over_dataset |
|
def pad_nonspaced_languages(x, text_key='text'): |
|
"""Pad non-spaced languages with spaces around each character. |
|
|
|
Args: |
|
x: an example to process. |
|
text_key: a string, the key for the text feature to preprocess in the |
|
dataset examples. |
|
|
|
Returns: |
|
A preprocessed example. |
|
""" |
|
res = dict(x) |
|
text = res[text_key] |
|
|
|
pattern = ''.join(NON_SPACED_LANGUAGE_RANGES) |
|
text = tf.strings.regex_replace(text, u'([{}])'.format(pattern), r' \1 ') |
|
|
|
text = tf.strings.regex_replace(text, r'\s+', ' ') |
|
res[text_key] = text |
|
return res |
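# Illustrative sketch (hypothetical input; not executed here). Characters in the
# ranges above are surrounded by spaces, and runs of whitespace are collapsed:
#
#   ds = tf.data.Dataset.from_tensor_slices({'text': ['你好world']})
#   ds = pad_nonspaced_languages(ds)
#   # 'text' becomes roughly ' 你 好 world' (a space around each CJK character;
#   # surrounding whitespace is collapsed but not stripped).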
|
|
|
|
|
def _pad_punctuation(text): |
|
"""Adds spaces around punctuation.""" |
|
|
|
text = tf.strings.regex_replace(text, r'([[:punct:]])', r' \1 ') |
|
|
|
text = tf.strings.regex_replace(text, r'\s+', ' ') |
|
return text |
|
|
|
|
|
def _string_join(lst): |
|
|
|
out = tf.strings.join(lst, separator=' ') |
|
return tf.strings.regex_replace(out, r'\s+', ' ') |
|
|
|
|
|
def trivia_qa(dataset): |
|
"""Convert a TriviaQA example to multiple flattened examples. |
|
|
|
TriviaQA produces examples with this form: |
|
{'entity_pages': {dict of wiki entities}, |
|
'search_results': <dict of web search results>, |
|
'answer': {dict of all answers}, 'question': <question>, |
|
'question_id': <question_id>, 'question_source': <question_source>} |
|
This function will return flattened examples of the format:

{'inputs': 'question: <question> context: <article>',

'targets': '<matched answer>'}
|
|
|
Args: |
|
dataset: a tf.data.Dataset to process. |
|
Returns: |
|
A preprocessed tf.data.Dataset with the format listed above. |
|
""" |
|
def triviaqa_question_answer_context(x): |
|
"""Extracts matched contexts and answers. |
|
|
|
Returns all matched (question-context, answer) pairs. |
|
|
|
Args: |
|
x: A tfds sample. |
|
|
|
Returns: |
|
Flattened samples: (question-context, answer). |
|
""" |
|
contexts = [] |
|
if 'entity_pages' in x: |
|
contexts.append(x['entity_pages']['wiki_context']) |
|
if 'search_results' in x: |
|
contexts.append(x['search_results']['search_context']) |
|
contexts = tf.concat(contexts, 0) |
|
|
|
q = _pad_punctuation(x['question']) |
|
answers = x['answer']['normalized_aliases'] |
|
|
|
combination_size = tf.size(answers)*tf.size(contexts) |
|
find_answers = tf.TensorArray( |
|
tf.bool, size=combination_size, dynamic_size=True) |
|
selected_answers = tf.TensorArray( |
|
tf.string, size=combination_size, dynamic_size=True) |
|
join_q_c = tf.TensorArray( |
|
tf.string, size=combination_size, dynamic_size=True) |
|
|
|
def cond_fn(i, find_answers, selected_answers, join_q_c): |
|
del find_answers, selected_answers, join_q_c |
|
return tf.less(i, combination_size) |
|
|
|
def body_fn(i, find_answers, selected_answers, join_q_c): |
|
"""Find answers from contexts and join.""" |
|
context_idx = tf.math.floordiv(i, tf.size(answers)) |
|
answer_idx = tf.math.mod(i, tf.size(answers)) |
|
|
|
a = _pad_punctuation(answers[answer_idx]) |
|
a_ = tf.strings.join(['.*', a, '.*']) |
|
c = _pad_punctuation(contexts[context_idx]) |
|
find_a = tf.strings.regex_full_match( |
|
tf.strings.lower(c), |
|
tf.strings.lower(a_)) |
|
find_answers = find_answers.write(i, find_a) |
|
selected_answers = selected_answers.write(i, a) |
|
|
|
join_q_c_str = _string_join(['question:', q, 'context:', c]) |
|
join_q_c = join_q_c.write(i, join_q_c_str) |
|
return (i + 1, find_answers, selected_answers, join_q_c) |
|
|
|
_, find_answers, selected_answers, join_q_c = tf.while_loop( |
|
cond_fn, |
|
body_fn, |
|
loop_vars=[ |
|
tf.constant(0), find_answers, selected_answers, |
|
join_q_c |
|
]) |
|
find_answers = find_answers.stack() |
|
selected_answers = selected_answers.stack() |
|
join_q_c = join_q_c.stack() |
|
|
|
selected_answers = tf.boolean_mask(selected_answers, find_answers) |
|
selected_join_q_c = tf.boolean_mask(join_q_c, find_answers) |
|
|
|
return selected_join_q_c, selected_answers |
|
|
|
def my_fn(x): |
|
"""Create TriviaQA example.""" |
|
join_q_c, a = triviaqa_question_answer_context(x) |
|
return { |
|
'inputs': join_q_c, |
|
'targets': a |
|
} |
|
|
|
dataset = dataset.map(my_fn, num_parallel_calls=AUTOTUNE) |
|
return dataset.unbatch() |
|
|
|
|
|
@seqio.map_over_dataset |
|
def squad(x, include_context=True): |
|
"""Convert SQuAD examples to a text2text pair. |
|
|
|
SQuAD produces examples with this form: |
|
{'id': <id>, 'context': <article>, 'question': <question>,
|
'answers': { 'text': [<n answers>] }} |
|
This function will return examples of the format: |
|
{'inputs': 'question: <question> context: <article>', |
|
'targets': '<answer_0>', |
|
'id': <id>, 'question': <question>, 'context': <context>, |
|
'answers': [<n answers>]}, |
|
|
|
Args: |
|
x: an example to process. |
|
include_context: a boolean, whether to include the context in the inputs.
|
Returns: |
|
A preprocessed example with the format listed above. |
|
""" |
|
a = _pad_punctuation(x['answers']['text']) |
|
q = _pad_punctuation(x['question']) |
|
c = _pad_punctuation(x['context']) |
|
if include_context: |
|
inputs = _string_join(['question:', q, 'context:', c]) |
|
else: |
|
inputs = _string_join(['squad trivia question:', q]) |
|
return { |
|
'inputs': inputs, |
|
'targets': a[0], |
|
'id': x['id'], |
|
'context': c, |
|
'question': q, |
|
'answers': a |
|
} |
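# Illustrative sketch (hypothetical toy example; not executed here). Punctuation
# is space-padded by _pad_punctuation before joining:
#
#   ex = {'id': '1', 'question': 'Who wrote it?', 'context': 'Bob wrote it.',
#         'answers': {'text': ['Bob']}}
#   ds = squad(tf.data.Dataset.from_tensors(ex))
#   # -> inputs  ~ 'question: Who wrote it ? context: Bob wrote it . '
#   #    targets ~ 'Bob', plus the passthrough 'id', 'context', 'question',
#   #    and 'answers' features.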
|
|
|
|
|
def _span_answer(context, answer_text): |
|
"""Finds start/end indices of answer_text in context after space tokenization. |
|
|
|
If answer_tokens is not a sublist of context_tokens, returns empty string. |
|
|
|
Args: |
|
context: 0-d string tensor |
|
answer_text: 0-d string |
|
|
|
Returns: |
|
A string tensor. |
|
""" |
|
def space_tok(s): |
|
"""Replace non-word chars with space then split on space.""" |
|
s = tf.strings.regex_replace(s, r'\W', ' ') |
|
return tf.strings.split(input=[s], sep=' ').values |
|
|
|
def find_subseq(n, h): |
|
"""Finds index of needle subsequence inside haystack. |
|
|
|
Args: |
|
n: 1-d tensor |
|
h: 1-d tensor same type as n |
|
|
|
Returns: |
|
Index of start of n if found; otherwise -1.
|
""" |
|
l_n = tf.size(n) |
|
l_h = tf.size(h) |
|
found = -1 |
|
for i in tf.range(0, l_h - l_n): |
|
if tf.reduce_all(tf.equal(h[i:i+l_n], n)): |
|
found = i |
|
break |
|
return found |
|
|
|
answer_tokens = space_tok(answer_text) |
|
context_tokens = space_tok(context) |
|
start = find_subseq(answer_tokens, context_tokens) |
|
end = start + tf.size(answer_tokens) - 1 |
|
|
|
if tf.equal(start, -1): |
|
return '' |
|
return tf.strings.format('start: {} end: {}', [start, end]) |
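# Worked sketch (eager-mode illustration with hypothetical values):
#
#   _span_answer(tf.constant('the quick brown fox jumps'),
#                tf.constant('brown fox'))
#   # Space tokenization gives ['the', 'quick', 'brown', 'fox', 'jumps'];
#   # 'brown fox' starts at index 2 and spans 2 tokens, so the result is
#   # 'start: 2 end: 3'.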
|
|
|
|
|
def squad_span_space_tokenized(dataset): |
|
"""Convert SQuAD examples to a text2text pair with span output. |
|
|
|
SQuAD produces examples with this form: |
|
{'context': <article>, 'question': <question>, |
|
'answers': { 'text': [<all answers>] }} |
|
|
|
This function returns examples with the format |
|
{'inputs': 'context: <article> question: <question>', |
|
'targets': 'start: <start_index> end: <end_index>'} |
|
where <start_index> and <end_index> specify the space-tokenized span |
|
start/end indices. Both <start_index> and <end_index> are included in |
|
the answer. In the case where the tokenized answer is |
|
not found in the tokenized context, the example is skipped. |
|
|
|
Args: |
|
dataset: a tf.data.Dataset to process. |
|
Returns: |
|
A preprocessed tf.data.Dataset with the format listed above. |
|
""" |
|
def my_fn(x): |
|
"""Create squad example as in squad_span_char, but tokenized on spaces.""" |
|
res = dict(x) |
|
res['targets'] = _span_answer(x['context'], x['targets']) |
|
return res |
|
|
|
dataset = squad(dataset) |
|
dataset = dataset.map(my_fn, num_parallel_calls=AUTOTUNE) |
|
return dataset.filter(lambda x: tf.strings.length(x['targets']) > 0) |
|
|
|
|
|
def random_split_text(dataset, |
|
text_key='text', |
|
min_words_per_segment=16, |
|
max_words_per_segment=512, |
|
max_words_total=8192): |
|
"""Randomly split single-string examples into multiple examples each. |
|
|
|
Segment lengths are chosen according to a log-uniform distribution. |
|
Each incoming string is chopped into multiple equal-length examples |
|
with the last one possibly being shorter. |
|
|
|
If the input string is longer than max_words_total, then we use one random |
|
chunk and discard the rest. This may help with model stability. |
|
|
|
The intended use case is to break up long text examples for use in |
|
unsupervised transfer-learning. |
|
|
|
We don't really want to use this preprocessor for any dataset which has a |
|
well-defined evaluation procedure. If we apply this preprocessor, e.g., in an

MT component, then the evaluation job will randomly split text when evaluating

and the BLEU score will get funky.
|
|
|
Args: |
|
dataset: a tf.data.Dataset with dictionaries containing the key text_key |
|
text_key: a string |
|
min_words_per_segment: an integer |
|
max_words_per_segment: an integer |
|
max_words_total: an integer |
|
|
|
Returns: |
|
a dataset |
|
""" |
|
def random_chunk(x, chunk_size, seed): |
|
"""Pick a random chunk of a 1d Tensor. |
|
|
|
The tensor is divided into chunks of length chunk_size, with the last |
|
chunk being potentially smaller. A random chunk is returned. |
|
|
|
Args: |
|
x: a 1d tf.Tensor. |
|
chunk_size: an integer. |
|
seed: int32 [2]-Tensor, the random seed. |
|
Returns: |
|
a 1d tf.Tensor with length <= chunk_size. |
|
""" |
|
size = tf.size(x) |
|
num_chunks = tf.maximum(1, (size - 1) // chunk_size + 1) |
|
chunk_num = tf.random.stateless_uniform( |
|
[], |
|
seed=seed, |
|
minval=0, |
|
maxval=num_chunks, |
|
dtype=tf.int32) |
|
return x[chunk_size * chunk_num:chunk_size * (chunk_num + 1)] |
|
|
|
@seqio.map_over_dataset(num_seeds=2) |
|
def my_fn(x, seeds): |
|
"""Split one string into multiple strings. |
|
|
|
Args: |
|
x: a feature dictionary |
|
seeds: an int32 Tensor, shaped (2, 2), the random seeds. |
|
Returns: |
|
a feature dictionary |
|
""" |
|
text = x[text_key] |
|
words = tf.strings.split([text]).values |
|
if max_words_total: |
|
words = random_chunk(words, max_words_total, seed=seeds[0]) |
|
n_words = tf.size(words) |
|
|
|
length = tf.cast( |
|
tf.exp( |
|
tf.random.stateless_uniform( |
|
[], |
|
minval=math.log(min_words_per_segment), |
|
maxval=math.log(max_words_per_segment), |
|
seed=seeds[1], |
|
) |
|
), |
|
tf.int32) |
|
|
|
|
|
num_segments = tf.cast( |
|
tf.math.ceil( |
|
tf.cast(n_words, tf.float32) / tf.cast(length, tf.float32) |
|
), |
|
tf.int32) |
|
padding = num_segments * length - n_words |
|
words = tf.pad(words, [[0, padding]]) |
|
words = tf.reshape(words, [-1, length]) |
|
|
|
|
|
words = tf.strings.reduce_join(words, axis=1, separator=' ') |
|
return {text_key: tf.strings.strip(words)} |
|
|
|
return my_fn(dataset).unbatch() |
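# Usage sketch (default parameters shown explicitly; hypothetical dataset `ds`).
# The segment length is drawn log-uniformly, i.e.
#   length = exp(Uniform(log(min_words_per_segment), log(max_words_per_segment))):
#
#   ds = random_split_text(
#       ds,
#       text_key='text',
#       min_words_per_segment=16,
#       max_words_per_segment=512,
#       max_words_total=8192)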
|
|
|
|
|
def split_text_to_words(dataset, text_key='text', min_num_words=2): |
|
"""Split text to words and filter out examples with too few words.""" |
|
def split(x): |
|
res = dict(x) |
|
res['words'] = tf.strings.split([x[text_key]]).values |
|
return res |
|
|
|
dataset = dataset.map(split, num_parallel_calls=AUTOTUNE) |
|
return dataset.filter(lambda x: tf.size(x['words']) >= min_num_words) |
|
|
|
|
|
def fill_in_the_blank(dataset, |
|
text_key='text', |
|
label='fill: '): |
|
"""Create a dataset consisting of fill-in-the-blank text examples. |
|
|
|
The input examples should have a key text_key associated with a tf.string |
|
value. |
|
|
|
The output examples have keys 'inputs' and 'targets'. |
|
|
|
The input string is split on whitespace to form a sequence of words. |
|
This sequence is chopped randomly into segments of one or more words. |
|
Alternate segments are included in the inputs and targets, with a special |
|
word 'X' marking a missing segment. |
|
|
|
The given label is prepended to the inputs. Each input string produces two |
|
examples - one the inverse of the other. Inputs with fewer than two words
|
are dropped. |
|
|
|
EXAMPLE: |
|
|
|
input: |
|
{ |
|
'text': 'The fat cat sat on the mat.' |
|
} |
|
outputs: |
|
{ |
|
'inputs': 'fill: The fat X the X' |
|
'targets': 'X cat sat on X mat.' |
|
} |
|
{ |
|
'inputs': 'fill: X cat sat on X mat.' |
|
'targets': 'The fat X the X' |
|
} |
|
|
|
Args: |
|
dataset: a tf.data.Dataset |
|
text_key: a string, the key for the text feature to preprocess in the |
|
dataset examples. |
|
label: a string, the label to prepend to the inputs. |
|
Returns: |
|
a tf.data.Dataset |
|
""" |
|
@seqio.map_over_dataset(num_seeds=3) |
|
def my_fn(x, seeds): |
|
"""Generates two preprocessed examples that are roughly inverses. |
|
|
|
Args: |
|
x: an example dict with text pre-split in `words` feature. |
|
seeds: an int32 Tensor, shaped (3, 2), the random seeds. |
|
Returns: |
|
an example dict with two inputs and two targets, one for each resulting |
|
preprocessed example. |
|
""" |
|
words = x['words'] |
|
n_words = tf.size(words) |
|
|
|
|
|
|
|
|
|
min_log_p_break = -tf.math.log(tf.cast(n_words, tf.float32) + 2.0) |
|
max_log_p_break = -tf.math.log(2.0) |
|
p_break = tf.exp( |
|
tf.random.stateless_uniform( |
|
[], |
|
minval=min_log_p_break, |
|
maxval=max_log_p_break, |
|
seed=seeds[0]) |
|
) |
|
|
|
|
|
|
|
breaks = tf.less( |
|
tf.random.stateless_uniform([n_words - 1], seed=seeds[1]), |
|
p_break) |
|
def one_random_break(): |
|
pos = tf.random.stateless_uniform( |
|
[], |
|
minval=0, |
|
maxval=n_words - 1, |
|
dtype=tf.int32, |
|
seed=seeds[2]) |
|
return tf.one_hot(pos, n_words - 1, |
|
dtype=tf.bool, on_value=True, off_value=False) |
|
breaks = tf.cond( |
|
tf.math.reduce_any(breaks), lambda: breaks, one_random_break) |
|
breaks = tf.concat([[True], breaks], axis=0) |
|
word_to_seq_id = tf.math.mod(tf.math.cumsum(tf.cast(breaks, tf.int32)), 2) |
|
|
|
|
|
|
|
|
|
results = [] |
|
for seq_id in [0, 1]: |
|
in_my_seq = tf.equal(word_to_seq_id, seq_id) |
|
separator_strings = tf.where( |
|
in_my_seq, |
|
' ', |
|
tf.where(breaks, ' X', '') |
|
) |
|
word_strings = tf.where(in_my_seq, words, '') |
|
all_strings = tf.stack([separator_strings, word_strings], axis=1) |
|
results.append(tf.strings.substr( |
|
tf.strings.reduce_join(all_strings), 1, tf.int32.max)) |
|
inputs = tf.stack([tf.strings.join([label, results[0]]), |
|
tf.strings.join([label, results[1]])]) |
|
targets = tf.stack([results[1], results[0]]) |
|
return {'inputs': inputs, 'targets': targets} |
|
dataset = split_text_to_words(dataset, text_key, min_num_words=2) |
|
return my_fn(dataset).unbatch() |
|
|
|
|
|
def fill_in_the_blank_sized( |
|
dataset, |
|
size_bins=(1, 2, 4, 8, 16, 32, 64, 128, 256, 512), |
|
text_key='text', |
|
label='fill: '): |
|
"""Fill in the blank preprocessor that labels blank with a binned size. |
|
|
|
The actual blank size is sampled uniformly from the inclusive range of the min |
|
and max bin. The blank is then filled in with the closest bin size to the |
|
actual blank size. |
|
|
|
Args: |
|
dataset: a tf.data.Dataset, the dataset to preprocess. |
|
size_bins: a list of blank sizes to select from when labelling the
|
blank. |
|
text_key: a string, the key for the text feature to preprocess in the |
|
dataset examples. |
|
label: a string, the label to prepend to the inputs. |
|
|
|
Returns: |
|
a tf.data.Dataset |
|
""" |
|
bins = sorted(size_bins) |
|
|
|
@seqio.map_over_dataset(num_seeds=2) |
|
def my_fn(x, seeds): |
|
"""Apply transformation.""" |
|
words = x['words'] |
|
n_words = tf.size(words) |
|
|
|
blank_size = tf.random.stateless_uniform( |
|
[], |
|
minval=bins[0], |
|
maxval=tf.math.minimum(n_words, bins[-1]), |
|
dtype=tf.dtypes.int32, |
|
seed=seeds[0]) |
|
bin_delta = tf.math.abs(bins - blank_size) |
|
bin_ = tf.gather(bins, tf.argmin(bin_delta)) |
|
blank_start = tf.random.stateless_uniform( |
|
[], |
|
minval=0, |
|
maxval=tf.math.maximum(0, n_words-blank_size) + 1, |
|
dtype=tf.dtypes.int32, |
|
seed=seeds[1]) |
|
|
|
pre_blank = tf.strings.reduce_join(words[0:blank_start], separator=' ') |
|
post_blank = tf.strings.reduce_join( |
|
words[blank_start+blank_size:], separator=' ') |
|
blank = tf.strings.format('_{}_', bin_) |
|
|
|
input_ = tf.strings.strip( |
|
tf.strings.join([pre_blank, blank, post_blank], ' ')) |
|
input_ = tf.strings.join([label, input_]) |
|
target = tf.strings.reduce_join( |
|
words[blank_start:blank_start+blank_size], separator=' ') |
|
return { |
|
'inputs': tf.strings.strip(input_), |
|
'targets': tf.strings.strip(target)} |
|
dataset = split_text_to_words(dataset, text_key, min_num_words=2) |
|
|
|
dataset = dataset.filter(lambda x: tf.size(x['words']) >= bins[0]) |
|
return my_fn(dataset) |
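# Illustrative sketch (one possible random draw; hypothetical input):
#
#   {'text': 'The fat cat sat on the mat .'}
#   # could become, e.g. with a sampled blank size of 4 starting at word 2:
#   # {'inputs': 'fill: The fat _4_ mat .', 'targets': 'cat sat on the'}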
|
|
|
|
|
def neighboring_pairs(dataset, text_key='text', reuse_sentences=True): |
|
"""Create a dataset consisting of neighboring sentence pairs. |
|
|
|
The input examples should have a key text_key associated with a tf.string |
|
value. |
|
|
|
The output examples have keys 'first' and 'second'. |
|
|
|
We only take sentence pairs from within the same line since lines seem to |
|
represent paragraph-like structures in our text datasets. Empty lines and |
|
1-sentence lines will thus be ignored. |
|
|
|
The argument reuse_sentences determines whether a sentence can be used as both |
|
the first and last element in the pair. For example, the input with sentences |
|
A,B,C,D will return (A,B),(B,C),(C,D) if reuse_sentences is True and |
|
(A,B),(C,D) if reuse_sentences is False. |
|
|
|
Args: |
|
dataset: a tf.data.Dataset |
|
text_key: a string, the key for the text feature to preprocess in the |
|
dataset examples. |
|
reuse_sentences: a boolean |
|
|
|
Returns: |
|
a tf.data.Dataset |
|
""" |
|
|
|
def split_by_lines(dataset): |
|
"""Splits text in dataset by line, removing empty lines.""" |
|
def my_fn(text): |
|
lines = tf.strings.split([text], sep='\n').values |
|
return tf.strings.strip(lines) |
|
|
|
dataset = dataset.map(my_fn, num_parallel_calls=AUTOTUNE) |
|
dataset = dataset.unbatch() |
|
return dataset.filter(lambda x: tf.strings.length(x) > 0) |
|
|
|
def split_into_pairs(line): |
|
"""Split a given text example into pairs of neighboring sentences.""" |
|
|
|
sep = str(uuid.uuid4()) |
|
sentences = tf.strings.regex_replace(line, r'((?:\.|\!|\?)+)', r'\1' + sep) |
|
sentences = tf.strings.strip(tf.strings.split([sentences], sep).values) |
|
if reuse_sentences: |
|
firsts = sentences[:-1] |
|
seconds = sentences[1:] |
|
else: |
|
firsts = sentences[:-1:2] |
|
seconds = sentences[1::2] |
|
return { |
|
'first': firsts, |
|
'second': seconds, |
|
} |
|
|
|
def example_len(x): |
|
return tf.math.minimum( |
|
tf.strings.length(x['first']), tf.strings.length(x['second'])) |
|
|
|
|
|
dataset = dataset.map(lambda x: x[text_key], num_parallel_calls=AUTOTUNE) |
|
dataset = split_by_lines(dataset) |
|
|
|
|
|
dataset = dataset.map(split_into_pairs, num_parallel_calls=AUTOTUNE) |
|
dataset = dataset.unbatch() |
|
|
|
|
|
dataset = dataset.filter(lambda x: example_len(x) > 0) |
|
return dataset |
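# Illustrative sketch (hypothetical input; not executed here):
#
#   ds = tf.data.Dataset.from_tensor_slices({'text': ['A! B. C?']})
#   pairs = neighboring_pairs(ds, text_key='text', reuse_sentences=True)
#   # -> {'first': 'A!', 'second': 'B.'}, {'first': 'B.', 'second': 'C?'}
#   # With reuse_sentences=False, only {'first': 'A!', 'second': 'B.'} remains.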
|
|
|
|
|
@seqio.map_over_dataset |
|
def glue(x, benchmark_name, label_names, feature_names=None, id_key='idx'): |
|
"""Convert a dataset from glue to text2text examples. |
|
|
|
This function uses the feature names from the dataset to unpack examples into |
|
a format amenable for a text2text problem. For example, consider the Quora |
|
Question Pairs (QQP) benchmark, which would suggest |
|
benchmark_name="qqp" |
|
label_names=['not_duplicate', 'duplicate'] |
|
For QQP, a typical example might look like |
|
{ |
|
"question1": "Why do I easily get bored of my friends?", |
|
"question2": "Why do I get bored of friends so quickly?", |
|
"label": 1, |
|
"idx": 10, |
|
} |
|
|
|
This example would be transformed to |
|
{ |
|
"inputs": ( |
|
"qqp question1: Why do I easily get bored of my friends? question2: " |
|
"Why do I get bored of my friends so quickly?" |
|
), |
|
"targets": "duplicate", |
|
"idx": 10, |
|
} |
|
|
|
Args: |
|
x: an example to process. |
|
benchmark_name: the name of the GLUE benchmark for this dataset. |
|
label_names: a list of label names corresponding to class index. |
|
feature_names: an optional ordered list of feature names. If provided, |
|
features will be ordered in this way in the output. If not provided, all |
|
features (except 'idx' and 'label') will be used, sorted by name. |
|
id_key: str, key for id in the dataset. If not provided, 'idx' will be used. |
|
If None, no id will be added to the dataset.
|
|
|
Returns: |
|
A preprocessed example. |
|
""" |
|
|
|
|
|
feature_keys = ( |
|
feature_names or sorted(set(x.keys()).difference(['label', 'idx']))) |
|
|
|
strs_to_join = [] |
|
for key in feature_keys: |
|
strs_to_join.append('{}:'.format(key)) |
|
strs_to_join.append(x[key]) |
|
|
|
strs_to_join.insert(0, benchmark_name) |
|
label_name = tf.cond( |
|
|
|
tf.equal(x['label'], -1), |
|
lambda: tf.constant('<unk>'), |
|
|
|
lambda: tf.gather(label_names, x['label']), |
|
) |
|
joined = tf.strings.join(strs_to_join, separator=' ') |
|
|
|
ex = {} |
|
|
|
if benchmark_name == 'multirc': |
|
|
|
joined = tf.strings.regex_replace(joined, '<br>', ' ') |
|
joined = tf.strings.regex_replace(joined, '<(/)?b>', '') |
|
|
|
|
|
ex['idx/paragraph'] = x['idx']['paragraph'] |
|
ex['idx/question'] = x['idx']['question'] |
|
ex['idx/answer'] = x['idx']['answer'] |
|
else: |
|
|
|
if id_key: |
|
ex['idx'] = x[id_key] |
|
|
|
ex['inputs'] = joined |
|
ex['targets'] = label_name |
|
|
|
return ex |
|
|
|
|
|
@seqio.map_over_dataset |
|
def stsb(x): |
|
"""Convert STSB examples to text2text format. |
|
|
|
STSB maps two sentences to a floating point number between 0 and 5
|
representing their semantic similarity. Since we are treating all tasks as |
|
text-to-text tasks we need to convert this floating point number to a string. |
|
The vast majority of the similarity score labels in STSB are in the set |
|
[0, 0.2, 0.4, ..., 4.8, 5.0]. So, we first round the number to the closest |
|
entry in this set, and then we convert the result to a string (literally e.g. |
|
"3.4"). This converts STSB roughly into a 26-class classification dataset. |
|
This function uses the feature names from the dataset to unpack examples into |
|
a format amenable for a text2text problem. |
|
|
|
For example, a typical example from STSB might look like |
|
{ |
|
"sentence1": "Three more US soldiers killed in Afghanistan", |
|
"sentence2": "NATO Soldier Killed in Afghanistan", |
|
"label": 1.8, |
|
} |
|
|
|
This example would be transformed to |
|
{ |
|
"inputs": ( |
|
"stsb sentence1: Three more US soldiers killed in Afghanistan " |
|
"sentence2: NATO Soldier Killed in Afghanistan" |
|
), |
|
"targets": "1.8", |
|
} |
|
|
|
Args: |
|
x: an example to process. |
|
Returns: |
|
A preprocessed example. |
|
""" |
|
strs_to_join = [ |
|
'stsb sentence1:', x['sentence1'], 'sentence2:', x['sentence2'] |
|
] |
|
label_string = tf.as_string(tf.round(x['label'] * 5) / 5, precision=1) |
|
joined = tf.strings.join(strs_to_join, separator=' ') |
|
return {'inputs': joined, 'targets': label_string, 'idx': x['idx']} |
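# Worked rounding sketch: a label of 1.84 maps to round(1.84 * 5) / 5 = 9 / 5
# = 1.8, so 'targets' is '1.8'; a label of 2.54 maps to round(12.7) / 5 = 13 / 5
# = 2.6, so 'targets' is '2.6'.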
|
|
|
|
|
@seqio.map_over_dataset |
|
def wsc(x): |
|
"""Convert WSC examples to text2text format. |
|
|
|
WSC includes a sentence along with 2 'spans': the first denoting a noun and |
|
the other a pronoun. The 'label' specifies whether or not the pronoun is |
|
referencing the noun. This preprocessor puts ' * ' around the noun and ' # ' |
|
around the pronoun. |
|
|
|
For example, a typical example from WSC might look like |
|
{ |
|
'text': 'This is a test sentence .', |
|
'span1_text': 'test', |
|
'span1_index': 3, |
|
'span2_text': 'This', |
|
'span2_index': 0, |
|
'label': 0 |
|
} |
|
|
|
This example would be transformed to |
|
{ |
|
'inputs': 'wsc text: # This # is a * test * sentence .', |
|
'targets': 'False' |
|
} |
|
|
|
Args: |
|
x: an example to process. |
|
Returns: |
|
A preprocessed example. |
|
""" |
|
|
|
def _mark_span(text, span_str, span_idx, mark): |
|
pattern_tmpl = r'^((?:\S+\s){N})(W)' |
|
pattern = tf.strings.regex_replace(pattern_tmpl, 'N', |
|
tf.as_string(span_idx)) |
|
pattern = tf.strings.regex_replace(pattern, 'W', span_str) |
|
return tf.strings.regex_replace(text, pattern, r'\1{0} \2 {0}'.format(mark)) |
|
|
|
text = x['text'] |
|
text = _mark_span(text, x['span1_text'], x['span1_index'], '*') |
|
|
|
span2_index = x['span2_index'] + 2 * tf.cast( |
|
x['span1_index'] < x['span2_index'], tf.int32) |
|
text = _mark_span(text, x['span2_text'], span2_index, '#') |
|
|
|
|
|
strs_to_join = ['wsc', 'text:', text] |
|
label_name = tf.cond( |
|
|
|
tf.equal(x['label'], -1), |
|
lambda: tf.constant('<unk>'), |
|
|
|
lambda: tf.gather(['False', 'True'], x['label'])) |
|
|
|
joined = tf.strings.join(strs_to_join, separator=' ') |
|
return {'inputs': joined, 'targets': label_name, 'idx': x['idx']} |
|
|
|
|
|
@gin.configurable |
|
def record(dataset): |
|
"""Convert ReCoRD examples to text2text examples. |
|
|
|
ReCoRD contains a passage, query containing a '@placeholder' string, and a set |
|
of entities that are the possible values of the placeholder. Each train and |
|
validation example will have a list of answers, any of which would be |
|
considered correct. |
|
|
|
For example, a typical example from ReCoRD might look like |
|
{ |
|
'passage': 'This is the passage.',
|
'query': 'A @placeholder is a bird.', |
|
'entities': ['penguin', 'potato', 'pigeon'], |
|
'answers': ['penguin', 'pigeon'], |
|
} |
|
which this preprocessor would turn into the following two examples: |
|
{ |
|
'inputs': 'record query: A @placeholder is a bird. entities: penguin, ' |
|
'potato, pigeon passage: This is the passage.', |
|
'targets': 'penguin', |
|
} |
|
and |
|
{ |
|
'inputs': 'record query: A @placeholder is a bird. entities: penguin, ' |
|
'potato, pigeon passage: This is the passage.', |
|
'targets': 'potato', |
|
} |
|
|
|
Args: |
|
dataset: a tf.data.Dataset to process. |
|
|
|
Returns: |
|
a tf.data.Dataset |
|
""" |
|
|
|
def process_answers(x): |
|
"""Helper fn to get one example per answer.""" |
|
ex = x.copy() |
|
num_answers = tf.size(ex['answers']) |
|
|
|
def duplicate_along_first_dim(t): |
|
n_duplicates = tf.math.maximum(num_answers, 1) |
|
return tf.broadcast_to( |
|
t, shape=tf.concat([[n_duplicates], tf.shape(t)], axis=0)) |
|
|
|
for k, v in x.items(): |
|
if k != 'idx': |
|
ex[k] = duplicate_along_first_dim(v) |
|
ex['targets'] = tf.cond( |
|
tf.greater(num_answers, 0), lambda: x['answers'], |
|
lambda: tf.constant(['<unk>'])) |
|
ex['idx'] = { |
|
'passage': duplicate_along_first_dim(x['idx']['passage']), |
|
'query': duplicate_along_first_dim(x['idx']['query']), |
|
} |
|
|
|
return ex |
|
|
|
def my_fn(x): |
|
"""Converts the processed example to text2text strings.""" |
|
passage = x['passage'] |
|
passage = tf.strings.regex_replace(passage, |
|
r'(\.|\?|\!|\"|\')\n@highlight\n', |
|
r'\1 ') |
|
passage = tf.strings.regex_replace(passage, r'\n@highlight\n', '. ') |
|
|
|
strs_to_join = [ |
|
'record query:', x['query'], 'entities:', |
|
tf.strings.reduce_join(x['entities'], separator=', '), 'passage:', |
|
passage |
|
] |
|
joined = tf.strings.join(strs_to_join, separator=' ') |
|
|
|
ex = {} |
|
|
|
|
|
ex['idx/passage'] = x['idx']['passage'] |
|
ex['idx/query'] = x['idx']['query'] |
|
|
|
ex['inputs'] = joined |
|
|
|
|
|
ex['targets'] = x['targets'] |
|
|
|
ex['answers'] = x['answers'] |
|
return ex |
|
|
|
dataset = dataset.map(process_answers, num_parallel_calls=AUTOTUNE) |
|
dataset = dataset.unbatch() |
|
return dataset.map(my_fn, num_parallel_calls=AUTOTUNE) |
|
|
|
|
|
def multi_translate(dataset, source_language, target_language): |
|
"""Convert a multi-translate dataset to a text2text pair. |
|
|
|
For example, say the dataset returns examples which have a 'translations' |
|
feature key so that examples have the following format: |
|
{ |
|
... |
|
'translations': { |
|
'language': ['de', 'fr', 'en'], |
|
'translation': ['Das ist gut.', 'Ca c'est bon', 'That is good.'] |
|
}, |
|
... |
|
} |
|
If source_language = 'de', target_language = 'en', then this function will |
|
return examples of the format: |
|
{'inputs': 'translate German to English: Das ist gut.',
|
'targets': 'That is good.'} |
|
Any other languages present in the dataset will be filtered out. |
|
|
|
Args: |
|
dataset: a tf.data.Dataset to process. |
|
source_language: source language code (e.g. 'en') to translate from. |
|
target_language: target language code (e.g. 'de') to translate to. |
|
Returns: |
|
A preprocessed tf.data.Dataset with the format listed above. |
|
""" |
|
def filter_fn(x): |
|
langs = x['translations']['language'] |
|
|
|
source_in_langs = tf.reduce_any(tf.equal(source_language, langs)) |
|
target_in_langs = tf.reduce_any(tf.equal(target_language, langs)) |
|
return tf.logical_and(source_in_langs, target_in_langs) |
|
def map_fn(x): |
|
langs = x['translations']['language'] |
|
|
|
src_idx = tf.squeeze(tf.where(tf.equal(langs, source_language))) |
|
tgt_idx = tf.squeeze(tf.where(tf.equal(langs, target_language))) |
|
return { |
|
source_language: x['translations']['translation'][src_idx], |
|
target_language: x['translations']['translation'][tgt_idx], |
|
} |
|
dataset = dataset.filter(filter_fn) |
|
dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE) |
|
return translate(dataset, source_language, target_language) |
|
|
|
|
|
@seqio.map_over_dataset |
|
def definite_pronoun_resolution_simple(x, label='wsc:'): |
|
"""Converts DPR examples to a simple text to text format. |
|
|
|
A typical example from the definite pronoun resolution dataset might look like |
|
{ |
|
'sentence': 'Bob asked Tom if he can lend some money.', |
|
'pronoun': 'he', |
|
'candidates': ['Bob', 'Tom'], |
|
'label': 1, |
|
} |
|
|
|
This will be transformed to |
|
{ |
|
'inputs': 'wsc: Bob asked Tom if *he* can lend some money.' |
|
'targets': 'Tom', |
|
} |
|
|
|
Args: |
|
x: an example to process. |
|
label: a string, the label to prepend to the inputs. |
|
|
|
Returns: |
|
A preprocessed example. |
|
""" |
|
|
|
|
|
inputs = [ |
|
label, |
|
tf.strings.regex_replace( |
|
x['sentence'], |
|
tf.strings.join([r' (', x['pronoun'], r')( |\.|,)']), |
|
r' *\1*\2', |
|
replace_global=False, |
|
), |
|
] |
|
return { |
|
'inputs': tf.strings.join(inputs, separator=' '), |
|
'targets': x['candidates'][x['label']], |
|
} |
|
|
|
|
|
def next_sentence_prediction(dataset, |
|
text_key='text', |
|
reuse_sentences=True, |
|
label_sentences=False, |
|
p_neighbors=0.5, |
|
label='nsp: ', |
|
buffer_size=50000): |
|
"""Create a dataset containing a next sentence prediction objective. |
|
|
|
The input examples should have a key text_key associated with a tf.string |
|
value. |
|
|
|
The output examples have keys 'inputs' and 'targets'. |
|
|
|
EXAMPLE OUTPUTS: |
|
|
|
{ |
|
input: "nsp: sentence1: The man went to the store. sentence2: Penguins are " |
|
"flightless birds.", |
|
target: "not_next" |
|
} |
|
|
|
The "sentence1:" and "sentence2:" labels will be omitted if label_sentences is |
|
False. |
|
|
|
Args: |
|
dataset: a tf.data.Dataset |
|
text_key: a string, the key for the text feature to preprocess in the |
|
dataset examples. |
|
reuse_sentences: a boolean, see docs for `neighboring_pairs` for more info. |
|
label_sentences: a boolean |
|
p_neighbors: a float between 0 and 1, the probability that a sentence pair |
|
will be neighbors. |
|
label: a string, the label to prepend to the inputs. |
|
buffer_size: an int, the size of the shuffle buffer used to get |
|
non-neighboring sentences. |
|
|
|
Returns: |
|
a tf.data.Dataset |
|
""" |
|
sentence1_label, sentence2_label = '', '' |
|
if label_sentences: |
|
sentence1_label, sentence2_label = 'sentence1: ', 'sentence2: ' |
|
|
|
empty = tf.constant('', dtype=tf.string, shape=[1]) |
|
|
|
dataset = neighboring_pairs( |
|
dataset, text_key=text_key, reuse_sentences=reuse_sentences) |
|
dataset = dataset.shuffle(buffer_size).batch(2, drop_remainder=True) |
|
|
|
def some_are_empty(*tensors): |
|
"""See if at least one tensor has shape [0].""" |
|
empty = [tf.equal(tf.size(t), 0) for t in tensors] |
|
return tf.reduce_any(empty) |
|
|
|
@seqio.map_over_dataset(num_seeds=1) |
|
def my_fn(x, seed): |
|
"""Function to be applied to each example in dataset.""" |
|
use_neighbors = ( |
|
tf.random.stateless_uniform(shape=[], seed=seed) < p_neighbors |
|
) |
|
firsts, seconds = tf.cond( |
|
use_neighbors, |
|
lambda: (x['first'], x['second']), |
|
lambda: (x['first'], tf.stack([x['second'][1], x['second'][0]])), |
|
) |
|
relation_label = tf.cond( |
|
use_neighbors, |
|
lambda: 'next', |
|
lambda: 'not_next', |
|
) |
|
|
|
inputs = [] |
|
for i in range(2): |
|
first_inputs = firsts[i] |
|
second_inputs = seconds[i] |
|
|
|
def create_examples(first_i=first_inputs, second_i=second_inputs): |
|
return tf.strings.join([ |
|
label, |
|
sentence1_label, |
|
first_i, |
|
' ', |
|
sentence2_label, |
|
second_i, |
|
]) |
|
|
|
inpt = tf.cond( |
|
some_are_empty(first_inputs, second_inputs), |
|
lambda: empty, |
|
create_examples, |
|
) |
|
inputs.append(tf.strings.strip(inpt)) |
|
inputs = tf.reshape(inputs, [-1]) |
|
targets = tf.reshape(2 * [relation_label], [-1]) |
|
|
|
return {'inputs': inputs, 'targets': targets} |
|
|
|
dataset = my_fn(dataset).unbatch() |
|
|
|
def example_len(x): |
|
return tf.math.minimum( |
|
tf.strings.length(x['inputs']), tf.strings.length(x['targets'])) |
|
|
|
|
|
return dataset.filter(lambda x: example_len(x) > 0) |
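# Usage sketch (hypothetical dataset `ds`; not executed here):
#
#   ds = next_sentence_prediction(
#       ds, text_key='text', label_sentences=True, p_neighbors=0.5)
#   # Each output example looks like:
#   # {'inputs': 'nsp: sentence1: <sentence A> sentence2: <sentence B>',
#   #  'targets': 'next'}   # or 'not_next' for shuffled (non-neighboring) pairs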
|
|
|
|
|
@seqio.map_over_dataset |
|
def lm(x): |
|
"""Basic language modeling objective for text - empty inputs. |
|
|
|
Given inputs with the format: |
|
{"text": "Here is some text."} |
|
This preprocessor produces examples with the format
|
{"inputs": "", "targets": "Here is some text."} |
|
|
|
Args: |
|
x: an example to process. |
|
|
|
Returns: |
|
A preprocessed example. |
|
""" |
|
return {'inputs': '', 'targets': x['text']} |
|
|
|
|
|
def _wsc_inputs(x): |
|
"""Given an example from SuperGLUE WSC, compute the 'inputs' value. |
|
|
|
The output will look like a fill in the blank with the pronoun blanked out. |
|
For example, the text |
|
'Mitchell asked Tom if he could lend some money.' |
|
would be transformed to |
|
'Mitchell asked Tom if X could lend some money.' |
|
|
|
Args: |
|
x: A dict that is an example from the WSC task of SuperGLUE. |
|
|
|
Returns: |
|
A scalar string tensor. |
|
""" |
|
words = tf.strings.split([x['text']], sep=' ').values |
|
|
|
|
|
|
|
|
|
with tf.control_dependencies([ |
|
tf.assert_greater(x['span2_index'], 0), |
|
tf.assert_less(x['span2_index'], tf.size(words)), |
|
]): |
|
pronoun_index = tf.identity(x['span2_index']) |
|
|
|
def create_input(): |
|
with tf.control_dependencies( |
|
[tf.assert_equal(words[pronoun_index], x['span2_text'])]): |
|
return tf.strings.join( |
|
[ |
|
tf.strings.reduce_join(words[:pronoun_index], separator=' '), |
|
'X', |
|
tf.strings.reduce_join( |
|
words[pronoun_index + 1:], separator=' '), |
|
], |
|
separator=' ', |
|
) |
|
|
|
|
|
if tf.equal( |
|
x['text'], |
|
'The boy continued to whip the pony , and eventually the pony threw him over. John laughed out quite loud. \"Good for him,\" he said. ' |
|
): |
|
return ( |
|
'The boy continued to whip the pony , and eventually the pony threw ' |
|
'him over. John laughed out quite loud. "Good for X ," he said.' |
|
) |
|
|
|
|
|
if tf.equal( |
|
x['text'], |
|
'When they had eventually calmed down a bit , and had gotten home, Mr. Farley put the magic pebble in an iron safe . Some day they might want to use it , but really for now, what more could they wish for?' |
|
): |
|
return ( |
|
'When they had eventually calmed down a bit , and had gotten home, ' |
|
'Mr. Farley put the magic pebble in an iron safe . Some day they might ' |
|
'want to use X , but really for now, what more could they wish for?' |
|
) |
|
|
|
return create_input() |
|
|
|
|
|
def wsc_simple(dataset, |
|
label='wsc:', |
|
correct_referent_only=False): |
|
"""Converts SuperGLUE WSC examples to a simple text to text format. |
|
|
|
A typical example from SuperGLUE WSC might look like |
|
{ |
|
'text': 'Mitchell asked Tom if he could lend some money.', |
|
'span1_text': 'Tom', |
|
'span2_text': 'he', |
|
'span2_index': 4, |
|
} |
|
|
|
This will be transformed to |
|
{ |
|
'inputs': 'wsc: Mitchell asked Tom if *he* could lend some money.',
|
'targets': 'Tom', |
|
} |
|
|
|
The targets will always be the text of the referent regardless of whether it |
|
is the correct referent of the pronoun. Thus, for training purposes, please
|
set `correct_referent_only` to be True. |
|
|
|
Args: |
|
dataset: a tf.data.Dataset |
|
label: a string, the label to prepend to the inputs. |
|
correct_referent_only: a bool, whether to filter out examples for which the |
|
targets is not the correct referent of the pronoun. |
|
|
|
Returns: |
|
a tf.data.Dataset |
|
""" |
|
|
|
def map_fn(x): |
|
"""Function to be called for every example in dataset.""" |
|
inputs = [ |
|
label, |
|
tf.strings.regex_replace( |
|
_wsc_inputs(x), r' X ', ' *' + x['span2_text'] + '* '), |
|
] |
|
referent = x['span1_text'] |
|
return { |
|
'inputs': tf.strings.join(inputs, separator=' '), |
|
|
|
'targets': tf.reshape(referent, shape=[]), |
|
'label': x.get('label', 0), |
|
'idx': x['idx'], |
|
} |
|
|
|
if correct_referent_only: |
|
dataset = dataset.filter(lambda x: tf.cast(x.get('label', False), tf.bool)) |
|
|
|
return dataset.map(map_fn, num_parallel_calls=AUTOTUNE) |
|
|
|
|
|
@seqio.map_over_dataset |
|
def wnli_simple(x, label='wsc:'): |
|
"""Converts GLUE WNLI examples to a simple text to text format. |
|
|
|
A typical example from WNLI might look like: |
|
{ |
|
'sentence1': 'The fish ate the worm. It was tasty.', |
|
'sentence2': 'The worm was tasty.', |
|
'label': 1, |
|
} |
|
|
|
This will be transformed to: |
|
{ |
|
'inputs': 'wsc: The fish ate the worm. *It* was tasty.', |
|
'targets': 'The worm', |
|
'premise': 'The fish ate the worm. It was tasty.',
|
'hypothesis': 'The worm was tasty.', |
|
'label': 1, |
|
} |
|
|
|
This preprocessor has been manually verified to produce reasonable WSC |
|
examples for the dev and test sets. Tasks using this preprocessor should only |
|
be used for eval and not for training.
|
|
|
Args: |
|
x: an example to process. |
|
label: a string, the label to prepend to the inputs. |
|
|
|
Returns: |
|
A preprocessed example. |
|
""" |
|
pronouns = ['he', 'she', 'they', 'it', 'her', 'his', 'their', 'them', 'him'] |
|
PronounMatch = collections.namedtuple( |
|
'PronounMatch', ['score', 'index_in_premise', 'candidate']) |
|
|
|
def split_clean(s): |
|
"""Returns array of words with punctuation and capitalization removed.""" |
|
words = [ |
|
re.sub(r'(\.|,|\?|\!)$', '', w) for w in s.strip().lower().split(' ') |
|
] |
|
return [w for w in words if w] |
|
|
|
def get_all_pronoun_indices(s): |
|
return [i for i, w in enumerate(s) if w in pronouns] |
|
|
|
def get_post_match_size(hypothesis, words): |
|
"""Returns len of largest prefix of words that is substr of hypothesis.""" |
|
hypothesis = ' '.join(hypothesis) |
|
for i in range(len(words)): |
|
if ' '.join(words[:i + 1]) not in hypothesis: |
|
return i |
|
return len(words) |
|
|
|
def get_pre_match_size(hypothesis, words): |
|
"""Returns len of largest suffix of words that is substr of hypothesis.""" |
|
return get_post_match_size(hypothesis[::-1], words[::-1]) |
|
|
|
def get_pronoun_match(premise, hypothesis, index): |
|
"""Return the PronounMatch for the pronoun at `index` in premise.""" |
|
pre, post = premise[:index], premise[index + 1:] |
|
|
|
pre_match_size = get_pre_match_size(hypothesis, pre) |
|
post_match_size = get_post_match_size(hypothesis, post) |
|
score = pre_match_size + post_match_size |
|
|
|
candidate = '' |
|
if score: |
|
pre_match = pre[-pre_match_size or len(pre):] |
|
post_match = post[:post_match_size] |
|
m = re.search(' '.join(pre_match + [r'(.+)'] + post_match), |
|
' '.join(hypothesis)) |
|
if not m: |
|
|
|
m = re.search(' '.join([r'^(.+)'] + post_match), ' '.join(hypothesis)) |
|
if not m: |
|
|
|
m = re.search(' '.join(pre_match + [r'(.+)$']), ' '.join(hypothesis)) |
|
|
|
if m: |
|
candidate = m.group(1) |
|
|
|
return PronounMatch( |
|
score=score, index_in_premise=index, candidate=candidate) |
|
|
|
def get_best_pronoun_match(premise, hypothesis): |
|
"""Returns the match for the pronoun in the premise to disambiguate.""" |
|
pronoun_indices = get_all_pronoun_indices(premise) |
|
scoredpronouns = [ |
|
get_pronoun_match(premise, hypothesis, index) |
|
for index in pronoun_indices |
|
] |
|
return max(scoredpronouns, key=lambda x: x.score) |
|
|
|
def highlight(sentence, index): |
|
words = sentence.split(' ') |
|
word = words[index] |
|
if word[-1] in ['.', ',', '!', '?']: |
|
highlighted = '*{}* {}'.format(word[:-1], word[-1]) |
|
else: |
|
highlighted = '*{}*'.format(word) |
|
return ' '.join(words[:index] + [highlighted] + words[index + 1:]) |
|
|
|
def make_nonpossessive(word): |
|
|
|
|
|
if word.endswith("'"): |
|
return word[:-1] |
|
elif word.endswith("'s"): |
|
return word[:-2] |
|
else: |
|
return word |
|
|
|
def clean_up(candidate): |
|
words = candidate.split(' ') |
|
|
|
|
|
|
|
|
|
article_index = max( |
|
[words.index(art) for art in {'a', 'an', 'the'} if art in words] or [0]) |
|
return ' '.join(words[article_index:]) |
|
|
|
def process_candidate(candidate, hypothesis): |
|
"""Handles special cases and adds proper punctuation/capitalization.""" |
|
candidate = clean_up(candidate) |
|
|
|
pattern = '({})'.format(' '.join([ |
|
r'{}(?:\.|,|\?|\!)?'.format(re.escape(c)) for c in candidate.split(' ') |
|
])) |
|
m = re.search(pattern, hypothesis, re.IGNORECASE) |
|
if not m: |
|
raise ValueError( |
|
'Unable to find candidate "{}" in hypothesis "{}".'.format( |
|
candidate, hypothesis)) |
|
|
|
candidate = m.group(1) |
|
if candidate and candidate[-1] in ['.', ',', '!', '?']: |
|
candidate = candidate[:-1] |
|
return make_nonpossessive(candidate) |
|
|
|
def compute_inputs_and_targets(premise, hypothesis): |
|
"""Compute inputs and targets for WNLI simple.""" |
|
premise = tf.compat.as_text(premise.numpy()) |
|
hypothesis = tf.compat.as_text(hypothesis.numpy()) |
|
|
|
match = get_best_pronoun_match( |
|
split_clean(premise), split_clean(hypothesis)) |
|
targets = process_candidate(match.candidate, hypothesis) |
|
inputs = '{} {}'.format(label, highlight(premise, match.index_in_premise)) |
|
return inputs, targets |
|
|
|
inputs, targets = tf.py_function( |
|
compute_inputs_and_targets, |
|
inp=[x['sentence1'], x['sentence2']], |
|
Tout=[tf.string, tf.string]) |
|
return { |
|
|
|
'inputs': tf.reshape(inputs, shape=[]), |
|
'targets': tf.reshape(targets, shape=[]), |
|
'premise': x['sentence1'], |
|
'hypothesis': x['sentence2'], |
|
'label': x.get('label', 0), |
|
'idx': x['idx'], |
|
} |
|
|
|
|
|
def rank_classification( |
|
ds: tf.data.Dataset, |
|
inputs_fn: Callable[[FeatureType], tf.Tensor], |
|
targets_fn: Callable[[FeatureType], tf.Tensor], |
|
is_correct_fn: Callable[[FeatureType], tf.Tensor], |
|
weight_fn: Optional[Callable[[FeatureType], tf.Tensor]] = None, |
|
mode: str = 'eval', |
|
passthrough_feature_keys: Optional[Sequence[str]] = None, |
|
) -> tf.data.Dataset: |
|
"""Prepare dataset for rank classification scoring. |
|
|
|
Intended to be used with `rank_classification` postprocessor and metric. |
|
|
|
`inputs_fn` and `targets_fn` must return the 'inputs' and 'targets' features, |
|
respectively, for each possible class label given the raw example features. |
|
'is_correct_fn' must return the 'is_correct' feature, a boolean for whether |
|
each label is correct. |
|
|
|
In 'train' mode, only the inputs / targets marked correct will be produced. |
|
In 'eval' mode, all inputs / targets will be produced. |
|
In 'fewshot_eval', all inputs / targets will be produced as a single batch. |
|
|
|
Each output example will also be given a unique 'idx' feature. The first dim |
|
is a sequential index for the input example and the second is the index of the |
|
generated output for it. E.g., the second output example from the fourth input |
|
example would be `[3, 1]`. |
|
|
|
To be clear, consider the following arguments: |
|
|
|
inputs_fn=lambda ex: ex['prefix'], |
|
targets_fn=lambda ex: ex['suffix'], |
|
is_correct_fn=lambda ex: tf.one_hot(ex['label'], num_classes) |
|
weight_fn=lambda ex: ex['weight'] |
|
|
|
Given the following example: |
|
|
|
{ |
|
'prefix': ['The farmland needed ', 'The farmland wanted '], |
|
'suffix': ['water', 'cows'], |
|
'label': 0, |
|
'weight': 1.0, |
|
} |
|
the preprocessor would return: |
|
|
|
[{ |
|
'idx': [0, 0], |
|
'inputs': 'The farmland needed ', |
|
'targets': 'water', |
|
'is_correct': True, |
|
'weight': 1.0 |
|
}, |
|
{ |
|
'idx': [0, 1], |
|
'inputs': 'The farmland wanted ', |
|
'targets': 'cows', |
|
'is_correct': False, |
|
'weight': 1.0 |
|
}] |
|
|
|
With mode set to 'train', it would return only the first example, |
|
since it uses the correct label. With mode set to 'fewshot_eval', it would |
|
return both examples in a single batch. |
|
|
|
Args: |
|
ds: a tf.data.Dataset to preprocess. |
|
inputs_fn: a callable that returns the 'inputs' features for each label |
|
given the input example. |
|
targets_fn: a callable that returns the 'targets' features for each label |
|
given the input example. |
|
is_correct_fn: a callable that returns the 'is_correct' feature, a boolean

Tensor with one value per label, indicating which label(s) are correct.
|
weight_fn: a callable that returns the 'weight' feature (float32 scalar). |
|
mode: A string, one of 'train', 'eval', or 'fewshot_eval'. 'train' produces only the correct
|
example(s) based on the label value(s). 'eval' produces an example for |
|
every possible class value, sequentially. 'fewshot_eval' produces an |
|
example for every possible class value, batched together for each input |
|
example. |
|
passthrough_feature_keys: a sequence of feature names that should be passed |
|
through to the output of this preprocessor, e.g. ["starburst", "tokens"].
|
|
|
Returns: |
|
A tf.data.Dataset containing 'idx', 'inputs', 'targets', and 'is_correct'.
|
""" |
|
if mode not in ('train', 'eval', 'fewshot_eval'): |
|
raise ValueError( |
|
"Mode must be one of 'train', 'eval', or 'fewshot_eval'. " |
|
f"Got '{mode}'.") |
|
|
|
def make_examples(idx, ex): |
|
inputs = inputs_fn(ex) |
|
targets = targets_fn(ex) |
|
is_correct = tf.cast(is_correct_fn(ex), tf.bool) |
|
|
|
tf.debugging.assert_equal( |
|
tf.size(is_correct), [tf.size(inputs), tf.size(targets)], |
|
'`inputs_fn`, `targets_fn`, and `is_correct_fn` must return the same ' |
|
'size tensors.') |
|
|
|
num_out = tf.size(is_correct) |
|
in_idx = tf.fill([num_out], tf.cast(idx, tf.int32)) |
|
out_idx = tf.range(num_out) |
|
|
|
output = { |
|
'idx': tf.stack([in_idx, out_idx], 1), |
|
'inputs': inputs, |
|
'targets': targets, |
|
'is_correct': is_correct, |
|
} |
|
|
|
if passthrough_feature_keys is not None: |
|
for feature_name in passthrough_feature_keys: |
|
output[feature_name] = [ex[feature_name]] * len(targets) |
|
|
|
if weight_fn is not None: |
|
output['weight'] = tf.fill(tf.shape(is_correct), weight_fn(ex)) |
|
output['weight'] = tf.cast(output['weight'], tf.float32) |
|
|
|
return output |
|
|
|
ds = ds.enumerate() |
|
ds = ds.map(make_examples, num_parallel_calls=AUTOTUNE) |
|
if mode != 'fewshot_eval': |
|
ds = ds.unbatch() |
|
if mode == 'train': |
|
ds = ds.filter(lambda ex: ex['is_correct']) |
|
return ds |
|
|
|
|
|
def rank_classification_formatter( |
|
ds: tf.data.Dataset, |
|
inputs_formats: Union[str, Sequence[str]], |
|
targets_formats: Union[str, Sequence[str]], |
|
mode: str = 'eval', |
|
label_key: str = 'label', |
|
weight_key: Optional[str] = None) -> tf.data.Dataset: |
|
"""Create 'inputs' and 'targets' strings for ranking classification. |
|
|
|
Intended to be used with `rank_classification` postprocessor and metric. |
|
|
|
Inputs will be formatted by filling in the feature values in the |
|
`inputs_formats` and `targets_formats` strings. |
|
|
|
Nested features can be accessed by concatenating the features using forward |
|
slash. For eg: if sub-sub-key is nested under sub-key, which is nested under |
|
key, then sub-sub-key can be accessed using key/sub-key/sub-sub-key. |
|
|
|
In 'eval' mode, a separate example will be produced for each targets / inputs |
|
format string. These can then be scored to find the one with the highest |
|
likelihood. The `rank_classification` postprocessor and metric allow you to |
|
evaluate with this technique. |
|
|
|
In 'train' mode, only the targets / inputs format string indexed by the |
|
label(s) will be produced. In 'eval' mode, all inputs / targets will be |
|
produced. |
|
|
|
Each input example will also be given a unique, sequential index called 'idx'. |
|
|
|
For example, with arguments: |
|
|
|
``` |
|
inputs_formats='{premise} What is the {question}? X',
|
targets_formats=[ |
|
'I think {choice1}.', |
|
'I think {choice2}.' |
|
], |
|
mode='eval' |
|
``` |
|
|
|
given the input: |
|
|
|
{ |
|
'premise': 'The farmland needed irrigation.', |
|
'question': 'effect', |
|
'choice1' : 'a canal was constructed', |
|
'choice2': 'the crops grew tall', |
|
'label': 0, |
|
} |
|
|
|
the preprocessor would return: |
|
[{ |
|
'idx': 0, |
|
'inputs': 'The farmland needed irrigation. What is the effect? X', |
|
'targets': 'I think a canal was constructed.', |
|
'is_correct': True |
|
}, |
|
{ |
|
'idx': 0, |
|
'inputs': 'The farmland needed irrigation. What is the effect? X', |
|
'targets': 'I think the crops grew tall.', |
|
'is_correct': False |
|
}] |
|
|
|
With `mode='train'`, it would return only the first example, |
|
since it uses the correct label. |
|
|
|
With `mode='fewshot_eval'`, it would return both examples in a single batch. |
|
|
|
Args: |
|
ds: a tf.data.Dataset to preprocess. |
|
inputs_formats: A string or a list of strings to format with feature values |
|
to produce 'inputs'. Feature keys should be surrounded by curly braces to |
|
be replaced. |
|
targets_formats: A string or a list of strings to format with feature values |
|
to produce 'targets', one for each possible class value. Feature keys |
|
should be surrounded by curly braces to be replaced. |
|
mode: A string, one of 'train', 'eval', or 'fewshot_eval'. 'train' produces
|
only the correct example(s) based on the label value(s). 'eval' produces |
|
an example for every possible class value, sequentially. |
|
'fewshot_eval' produces an example for every possible class value,
|
batched together for each input example. |
|
label_key: A string, the feature key for the integer label value(s). |
|
weight_key: A string, the feature key for the float example weight. |
|
|
|
Returns: |
|
A tf.data.Dataset containing 'idx', 'inputs', 'targets', and 'is_correct'.
|
""" |
|
if (isinstance(inputs_formats, (list, tuple)) and |
|
isinstance(targets_formats, (list, tuple))): |
|
if len(inputs_formats) != len(targets_formats): |
|
raise ValueError( |
|
f'The inputs_formats ({len(inputs_formats)}) and ' |
|
f'targets_formats ({len(targets_formats)}) are both instances ' |
|
'of list or tuple, but do not have matching lengths.') |
|
elif isinstance(inputs_formats, (list, tuple)): |
|
num_classes = len(inputs_formats) |
|
targets_formats = [targets_formats] * num_classes |
|
elif isinstance(targets_formats, (list, tuple)): |
|
num_classes = len(targets_formats) |
|
inputs_formats = [inputs_formats] * num_classes |
|
else: |
|
raise ValueError( |
|
'One of the inputs_formats and targets_formats has to ' |
|
f'be a list or tuple, inputs_formats: {inputs_formats}, ' |
|
f'target_formats: {targets_formats}.') |
|
|
|
def _format_str(features, fmt): |
|
keys = set(re.findall(r'{(\S+)}', fmt)) |
|
s = fmt |
|
for k in keys: |
|
value = features |
|
for subkey in k.split('/'): |
|
value = value[subkey] |
|
if not isinstance(value, tf.Tensor): |
|
raise ValueError( |
|
f'Final value of key \'{k}\' must be a tf.string. ' |
|
f'Got: {type(value).__name__}') |
|
tf.debugging.assert_type( |
|
value, tf.string, |
|
f'Final value of key \'{k}\' must be a tf.string. ' |
|
f'Got: {value.dtype.name}') |
|
s = tf.strings.regex_replace(s, '{%s}' % k, value) |
|
return s |
|
|
|
def _apply_formats(features, fmts): |
|
return [_format_str(features, fmt) for fmt in fmts] |
|
|
|
def _is_correct_fn(ex): |
|
labels = ex[label_key] |
|
is_correct = tf.one_hot(labels, num_classes, on_value=True, off_value=False) |
|
if labels.shape.rank: |
|
is_correct = tf.math.reduce_any(is_correct, axis=0) |
|
return is_correct |
|
|
|
def _weight_fn(ex): |
|
return ex[weight_key] |
|
|
|
return rank_classification( |
|
ds, |
|
inputs_fn=functools.partial(_apply_formats, fmts=inputs_formats), |
|
targets_fn=functools.partial(_apply_formats, fmts=targets_formats), |
|
is_correct_fn=_is_correct_fn, |
|
weight_fn=None if weight_key is None else _weight_fn, |
|
mode=mode) |
|
|
|
|
|
@seqio.map_over_dataset |
|
def parse_tsv(line, field_names=None, field_delim='\t'): |
|
"""Splits TSV lines into dict examples mapping field name to string value. |
|
|
|
Args: |
|
line: an example containing a comma/tab-delimited string. |
|
field_names: a list of strings, the ordered names of the TSV fields. |
|
Defaults to "inputs" and "targets". |
|
field_delim: a string, the delimiter to split on e.g. ',' for csv. |
|
|
|
Returns: |
|
A feature dict mapping field name to string value. |
|
""" |
|
field_names = field_names or ['inputs', 'targets'] |
|
return dict( |
|
zip(field_names, |
|
tf.io.decode_csv( |
|
line, |
|
record_defaults=[''] * len(field_names), |
|
field_delim=field_delim, |
|
use_quote_delim=False))) |
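# Usage sketch (hypothetical file path; not executed here):
#
#   ds = tf.data.TextLineDataset('data.tsv')      # lines like 'hello\tbonjour'
#   ds = parse_tsv(ds, field_names=['inputs', 'targets'])
#   # -> {'inputs': b'hello', 'targets': b'bonjour'}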
|
|
|
|
|
@seqio.map_over_dataset |
|
def preprocess_tsv(line, |
|
field_delim='\t', |
|
num_fields=2, |
|
inputs_format='{0}', |
|
targets_format='{1}', |
|
field_names=None): |
|
r"""Parse tab-delimited strings into inputs and targets. |
|
|
|
This function takes a tf.data.Dataset of strings, each of which contains |
|
tab-delimited fields. The function returns a tf.data.Dataset of feature |
|
dictionaries of the form {"inputs": string, "targets": string}. |
|
|
|
inputs_format contains a template string and field numbers or names used to |
|
produce the "inputs" string. |
|
targets_format contains a template string and field numbers or names used to |
|
produce the "targets" string. |
|
|
|
Example (field numbers): |
|
The input dataset contains the lines: |
|
"6,7,42" |
|
"2,9,18" |
|
preprocess_tsv(dataset, |
|
field_delim=',', |
|
inputs_format='numerator: {2} denominator: {1}', |
|
targets_format='quotient: {0}') |
|
would produce a dataset containing the dictionaries: |
|
{"inputs": "numerator: 42 denominator: 7", "targets": "quotient: 6"} |
|
{"inputs": "numerator: 18 denominator: 9", "targets": "quotient: 2"} |
|
|
|
Example (field names): |
|
The input dataset contains the lines: |
|
"6,7,42" |
|
"2,9,18" |
|
preprocess_tsv(dataset, |
|
field_delim=',', |
|
field_names=['quot', 'denom', 'numer'], |
|
inputs_format='numerator: {numer} denominator: {denom}', |
|
targets_format='quotient: {quot}') |
|
would produce a dataset containing the dictionaries: |
|
{"inputs": "numerator: 42 denominator: 7", "targets": "quotient: 6"} |
|
{"inputs": "numerator: 18 denominator: 9", "targets": "quotient: 2"} |
|
|
|
Args: |
|
line: an example containing a comma/tab-delimited string. |
|
field_delim: a string, the delimiter to split on, e.g. ',' for csv. |
|
num_fields: an integer, the number of fields in each line; used when |
|
field_names is not provided. |
|
inputs_format: a string, the desired output format with placeholders for |
|
field values. |
|
targets_format: a string, the desired output format with placeholders for |
|
field values. |
|
field_names: a list of strings, the ordered names of the TSV fields. |
|
Defaults to None (i.e. use field numbers in *_format). |
|
Returns: |
|
A feature dict with 'inputs' and 'targets' features. |
|
""" |
|
def _format_part_with_field_numbers(part, field_values): |
|
found = re.findall(r'{(\d+)}', part) |
|
if found: |
|
return field_values[int(found[0])] |
|
else: |
|
return part |
|
|
|
def _format_part_with_field_names(part, field_names, field_values): |
|
field_names_re = '|'.join(['{{({})}}'.format(x) for x in field_names]) |
|
found = re.findall(field_names_re, part) |
|
if found: |
|
pos = field_names.index(''.join(found[0])) |
|
return field_values[int(pos)] |
|
else: |
|
return part |
|
|
|
def _format(format_string, field_names, field_values): |
|
if field_names is None: |
|
parts = [ |
|
_format_part_with_field_numbers(p, field_values) |
|
for p in re.split(r'({\d+})', format_string) |
|
] |
|
else: |
|
field_names_re = '(' + '|'.join(['{{{}}}'.format(x) for x in field_names |
|
]) + ')' |
|
parts = [ |
|
_format_part_with_field_names(p, field_names, field_values) |
|
for p in re.split(field_names_re, format_string) |
|
] |
|
return tf.strings.join(parts) |
|
|
|
field_values = tf.io.decode_csv( |
|
line, |
|
record_defaults=[''] * |
|
(num_fields if field_names is None else len(field_names)), |
|
field_delim=field_delim, |
|
use_quote_delim=False) |
|
return { |
|
'inputs': _format(inputs_format, field_names, field_values), |
|
'targets': _format(targets_format, field_names, field_values) |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
def span_corruption(dataset, |
|
sequence_length, |
|
output_features, |
|
mean_noise_span_length=3.0, |
|
noise_density=0.15, |
|
input_feature_key='inputs', |
|
merge_examples_to_reduce_padding=True, |
|
reserved_for_packing=None): |
|
"""Final pretraining objective used in Raffel et al., 2019. |
|
|
|
Args: |
|
dataset: A tf.data.Dataset with dictionaries containing the key |
|
`input_feature_key`. |
|
sequence_length: dict mapping feature key to int length for that feature. |
|
output_features: mapping of keys to features. |
|
mean_noise_span_length: the mean number of tokens per masked span per |
|
example. |
|
noise_density: what fraction of the tokens to mask. |
|
input_feature_key: which feature to use from the dataset as the input text |
|
tokens. |
|
merge_examples_to_reduce_padding: if True, combines multiple input examples |
|
to reduce padding. |
|
reserved_for_packing: if specified, reduces the desired inputs length by the |
|
specified amount to enable multiple examples to be packed together |
|
downstream. |
|
|
|
Returns: |
|
a dataset |
|
""" |
|
inputs_length = sequence_length[input_feature_key] |
|
if reserved_for_packing: |
|
inputs_length -= reserved_for_packing |
|
|
|
input_length, targets_length = random_spans_helper( |
|
extra_tokens_per_span_inputs=1, |
|
extra_tokens_per_span_targets=1, |
|
inputs_length=inputs_length, |
|
mean_noise_span_length=mean_noise_span_length, |
|
noise_density=noise_density) |
|
|
|
if sequence_length['targets'] < targets_length: |
|
raise ValueError( |
|
f'Expected targets length for span corruption ({targets_length}) is ' |
|
f'greater than configured targets length ' |
|
f"({sequence_length['targets']})") |
|
|
|
ds = dataset |
|
ds = select_random_chunk( |
|
ds, |
|
output_features=output_features, |
|
feature_key='targets', |
|
max_length=65536) |
|
if merge_examples_to_reduce_padding: |
|
ds = reduce_concat_tokens(ds, feature_key='targets', batch_size=128) |
|
ds = split_tokens( |
|
ds, |
|
feature_key='targets', |
|
min_tokens_per_segment=None, |
|
max_tokens_per_segment=input_length) |
|
ds = denoise( |
|
ds, |
|
output_features, |
|
inputs_fn=noise_span_to_unique_sentinel, |
|
targets_fn=nonnoise_span_to_unique_sentinel, |
|
noise_density=noise_density, |
|
noise_mask_fn=functools.partial( |
|
random_spans_noise_mask, |
|
mean_noise_span_length=mean_noise_span_length), |
|
input_feature_key=input_feature_key) |
|
return ds |
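|
# Minimal usage sketch (assumes a tokenized dataset `ds` with a 'targets' |
|
# feature and output_features sharing one vocabulary; values illustrative): |
|
#   ds = span_corruption( |
|
#       ds, sequence_length={'inputs': 512, 'targets': 114}, |
|
#       output_features=output_features) |
|
# With the default noise_density=0.15 and mean_noise_span_length=3.0, a |
|
# targets length of 114 matches what random_spans_helper reports for |
|
# 512-token inputs. |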
|
|
|
|
|
|
|
def iid_denoising(dataset, sequence_length, output_features): |
|
"""Baseline pretraining objective used in Raffel et al., 2019.""" |
|
ds = dataset |
|
ds = select_random_chunk(ds, output_features=output_features, |
|
feature_key='targets', max_length=65536) |
|
ds = reduce_concat_tokens(ds, feature_key='targets', batch_size=128) |
|
ds = split_tokens_to_inputs_length(ds, output_features=output_features, |
|
sequence_length=sequence_length) |
|
ds = denoise( |
|
ds, |
|
output_features, |
|
inputs_fn=noise_span_to_unique_sentinel, |
|
targets_fn=nonnoise_span_to_unique_sentinel, |
|
noise_density=0.15, |
|
noise_mask_fn=iid_noise_mask |
|
) |
|
return ds |
|
|
|
|
|
def prefix_lm(dataset, sequence_length, output_features): |
|
"""Prefix language modeling objective used in Raffel et al. 2019.""" |
|
ds = dataset |
|
ds = select_random_chunk(ds, output_features=output_features, |
|
feature_key='targets', max_length=65536) |
|
ds = split_tokens_to_inputs_length(ds, output_features=output_features, |
|
sequence_length=sequence_length) |
|
ds = denoise( |
|
ds, |
|
output_features, |
|
inputs_fn=drop_nonnoise_tokens, |
|
targets_fn=drop_noise_tokens, |
|
noise_density=0.5, |
|
noise_mask_fn=random_prefix_noise_mask, |
|
) |
|
return ds |
|
|
|
|
|
def full_lm(dataset, sequence_length, output_features): |
|
"""Full language modeling objective with EOS only at document boundaries.""" |
|
ds = dataset |
|
ds = select_random_chunk(ds, output_features=output_features, |
|
feature_key='targets', max_length=65536) |
|
ds = seqio.preprocessors.append_eos(ds, output_features) |
|
ds = reduce_concat_tokens(ds, feature_key='targets', batch_size=128) |
|
|
|
ds = split_tokens(ds, max_tokens_per_segment=sequence_length['targets']) |
|
return ds |
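|
# For example, with sequence_length={'targets': 512}, full_lm yields |
|
# 512-token chunks of concatenated documents (the trailing chunk of each |
|
# concatenated group may be shorter), with EOS appearing only where one |
|
# document ends and the next begins. |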
|
|
|
|
|
@gin.configurable |
|
def select_random_chunk(dataset: tf.data.Dataset, |
|
output_features: Mapping[str, seqio.Feature], |
|
max_length: Optional[int] = None, |
|
feature_key: str = 'targets', |
|
additional_feature_keys: Optional[Sequence[str]] = None, |
|
passthrough_feature_keys: Optional[ |
|
Sequence[str]] = None, |
|
sequence_length: Optional[Mapping[str, int]] = None, |
|
uniform_random_start: bool = False, |
|
min_length: Optional[int] = None, |
|
**unused_kwargs) -> tf.data.Dataset: |
|
"""Token-preprocessor to extract one span of at most `max_length` tokens. |
|
|
|
If the token sequence is longer than `max_length`, then we return a random |
|
subsequence. Otherwise, we return the full sequence. |
|
|
|
This is generally followed by split_tokens. |
|
|
|
Args: |
|
dataset: A tf.data.Dataset with dictionaries containing the key feature_key. |
|
output_features: Mapping of keys to features. |
|
max_length: Typically specified in gin configs, takes priority over |
|
sequence_length. |
|
feature_key: Which feature to use from the dataset. |
|
additional_feature_keys: Additional features to use. The same chunk will be |
|
selected from these features as from the one specified in feature_key, |
|
so they should all have the same length. |
|
passthrough_feature_keys: Additional keys to pass through unchanged. |
|
sequence_length: Used if max_length is not specified. Typically passed in |
|
by the data pipeline. feature_key will be used to select the length. |
|
uniform_random_start: If True, will select a starting point in |
|
[-max_length + 1, n_tokens). If False, will select one of a set of chunks |
|
offset by max_length. Both of these starting points try to ensure each |
|
token has an equal probability of being included. |
|
min_length: If specified, lengths of chunks will be selected uniformly at |
|
random from [min_length, max_length]. Note that chunks can end up shorter |
|
than min_length if they fall at the beginning or end of the sequence. |
|
|
|
Returns: |
|
a dataset |
|
""" |
|
if passthrough_feature_keys: |
|
chunk_keys = set([feature_key] + (additional_feature_keys or [])) |
|
overlap_keys = chunk_keys & set(passthrough_feature_keys) |
|
if overlap_keys: |
|
raise ValueError( |
|
f'chunk keys {overlap_keys} also included in passthrough keys') |
|
|
|
if max_length is None and sequence_length is not None: |
|
max_length = sequence_length[feature_key] |
|
if output_features[feature_key].add_eos: |
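|
# Leave room for an EOS token to be appended downstream. |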
|
|
|
max_length -= 1 |
|
if max_length is None: |
|
raise ValueError('Must specify max_length or sequence_length.') |
|
|
|
@seqio.map_over_dataset(num_seeds=2) |
|
def _my_fn(x, seeds): |
|
"""Select a random chunk of tokens. |
|
|
|
Args: |
|
x: a 1d Tensor |
|
seeds: an int32 Tensor, shaped (2, 2), the random seeds. |
|
Returns: |
|
a 1d Tensor |
|
""" |
|
tokens = x[feature_key] |
|
n_tokens = tf.shape(tokens)[0] |
|
if min_length is not None: |
|
length = tf.random.stateless_uniform( |
|
[], |
|
minval=min_length, |
|
maxval=max_length, |
|
dtype=tf.int32, |
|
seed=seeds[0]) |
|
else: |
|
length = max_length |
|
if uniform_random_start: |
|
start = tf.random.stateless_uniform( |
|
[], |
|
minval=-length + 1, |
|
maxval=n_tokens, |
|
dtype=tf.int32, |
|
seed=seeds[1]) |
|
end = tf.minimum(start + length, n_tokens) |
|
start = tf.maximum(start, 0) |
|
else: |
|
num_segments = tf.cast( |
|
tf.math.ceil( |
|
tf.cast(n_tokens, tf.float32) / tf.cast(length, tf.float32) |
|
), |
|
tf.int32) |
|
start = length * tf.random.stateless_uniform( |
|
[], |
|
maxval=num_segments, |
|
dtype=tf.int32, |
|
seed=seeds[1]) |
|
end = tf.minimum(start + length, n_tokens) |
|
chunk = {feature_key: tokens[start:end]} |
|
if additional_feature_keys is not None: |
|
for k in additional_feature_keys: |
|
with tf.control_dependencies([ |
|
tf.assert_equal( |
|
tf.shape(tokens)[0], |
|
tf.shape(x[k])[0], |
|
message=(f'Additional feature {k} is not the same size as ' |
|
f'{feature_key} along axis 0 in select_random_chunk().' |
|
) |
|
) |
|
]): |
|
chunk[k] = x[k][start:end] |
|
if passthrough_feature_keys is not None: |
|
for k in passthrough_feature_keys: |
|
chunk[k] = x[k] |
|
return chunk |
|
|
|
dataset = dataset.filter(lambda x: tf.not_equal(tf.size(x[feature_key]), 0)) |
|
return _my_fn(dataset) |
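|
# Worked example of the default (uniform_random_start=False) branch: with |
|
# max_length=512 and a 1300-token document, num_segments = ceil(1300/512) = 3, |
|
# so the chunk starts at 0, 512, or 1024 (chosen uniformly) and the last |
|
# possible chunk is the 276-token remainder. |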
|
|
|
|
|
@gin.configurable |
|
def reduce_concat_tokens(dataset, |
|
feature_key='targets', |
|
batch_size=128, |
|
**unused_kwargs): |
|
"""Token-preprocessor to concatenate multiple unrelated documents. |
|
|
|
If we want to generate examples of exactly the right length |
|
(to avoid wasting space on padding), then we use this function, followed by |
|
split_tokens. |
|
|
|
Args: |
|
dataset: a tf.data.Dataset with dictionaries containing the key feature_key. |
|
feature_key: a string, the key of the feature to concatenate |
|
batch_size: an integer - how many documents to concatenate into one |
|
|
|
Returns: |
|
a dataset |
|
""" |
|
dataset = dataset.map( |
|
lambda x: {feature_key: x[feature_key]}, num_parallel_calls=AUTOTUNE) |
|
dataset = dataset.padded_batch(batch_size, padded_shapes={feature_key: [-1]}) |
|
def _my_fn(x): |
|
tokens = tf.reshape(x[feature_key], [-1]) |
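|
# padded_batch pads with zeros, so dropping the zero-valued ids below keeps |
|
# only the real tokens of the concatenated documents. |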
|
|
|
tokens = tf.boolean_mask(tokens, tf.cast(tokens, tf.bool)) |
|
return {feature_key: tokens} |
|
|
|
return dataset.map(_my_fn, num_parallel_calls=AUTOTUNE) |
|
|
|
|
|
@seqio.map_over_dataset |
|
def trim_tokens_at_front(x, |
|
sequence_length, |
|
keys_to_trim=None, |
|
**unused_kwargs): |
|
"""Token-preprocessor to trim sequence at the beginning. |
|
|
|
Args: |
|
x: an example with dictionaries containing keys_to_trim. |
|
sequence_length: a dict of ints. |
|
keys_to_trim: a list of feature keys. |
|
|
|
Returns: |
|
A preprocessed example. |
|
""" |
|
|
|
for key in (keys_to_trim or sequence_length.keys()): |
|
if key in x: |
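|
# Keep only the last sequence_length[key] - 1 tokens, leaving room for a |
|
# trailing EOS token. |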
|
|
|
x[key] = x[key][-(sequence_length[key] - 1):] |
|
return x |
|
|
|
|
|
def trivia_qa_truncate_inputs(dataset, output_features, sequence_length): |
|
"""Token preprocessor for the trivia QA dataset to truncate inputs. |
|
|
|
This function takes a dataset containing "targets" and "inputs". It searches |
|
for the "targets" in the "inputs" and truncates the "inputs" to |
|
`sequence_length` while ensuring that the "targets" are present in the |
|
"inputs". The function will randomly select a subset of "inputs". |
|
If "targets" are not found in the "inputs", then the example is |
|
is dropped from the dataset. |
|
|
|
E.g. |
|
Input dataset |
|
{ |
|
"inputs": [0, 3, 5, 7, 9, 11, 13, 15, 17, 18] |
|
"targets": [5, 7, 9] |
|
} |
|
|
|
Output dataset (assuming sequence_length['inputs'] = 4) |
|
{ |
|
"inputs": [3, 5, 7, 9] |
|
"targets": [5, 7, 9] |
|
} |
|
|
|
or |
|
|
|
{ |
|
"inputs": [5, 7, 9, 11] |
|
"targets": [5, 7, 9] |
|
} |
|
Args: |
|
dataset: a tf.data.Dataset with dictionaries containing the "inputs" and |
|
"targets". |
|
output_features: unused by this function. |
|
sequence_length: a dict, with keys as "inputs" and "targets" indicating the |
|
maximum number of tokens in each of the sequences. |
|
|
|
Returns: |
|
a dataset |
|
|
|
""" |
|
|
|
del output_features |
|
|
|
@seqio.map_over_dataset(num_seeds=1) |
|
def my_fn(features, seed): |
|
"""Function to map original dataset to the new dataset.""" |
|
inputs = features['inputs'] |
|
targets = features['targets'] |
|
ans_len = tf.shape(targets)[0] |
|
max_input_tokens = sequence_length['inputs'] |
|
|
|
def truncate_inputs(): |
|
"""Helper function to truncate the inputs.""" |
|
|
|
def answer_in_context(context, answer): |
|
"""Helper function that checks if the answer is present in the context. |
|
|
|
Args: |
|
context: Tensor, tokenized representation of the context |
|
answer: Tensor, tokenized representation of the answer |
|
|
|
Returns: |
|
result: boolean, indicates if the answer was present in the context. |
|
pos_mask: boolean mask, a mask for every possible start position of |
|
the answer in the context. Indicates whether the answer starts at |
|
the particular position. |
|
""" |
|
conv_inp = tf.reshape(tf.cast(context, tf.float32), [1, -1, 1]) |
|
ans_len = tf.shape(answer)[0] |
|
filters = tf.eye(ans_len, dtype=tf.float32) |
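|
# The identity filter bank below turns conv1d into a sliding window: output |
|
# position i holds context[i:i + ans_len], so every window can be compared |
|
# against the answer in a single vectorized equality check. |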
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
strided = tf.nn.conv1d(conv_inp, |
|
tf.reshape(filters, [ans_len, 1, ans_len]), 1, |
|
'VALID') |
|
strided = tf.cast(strided[0], answer.dtype) |
|
pos_mask = tf.reduce_all( |
|
tf.equal(strided, tf.reshape(answer, [1, -1])), 1) |
|
result = tf.reduce_any(pos_mask) |
|
return result, pos_mask |
|
|
|
def slice_inputs(inputs, answer_len, pos_mask, seed=None): |
|
"""Helper function to slice inputs while keeping the answer.""" |
|
ans_start_pos = tf.cast(tf.where(pos_mask)[0][0], tf.int32) |
|
inputs_len = tf.shape(inputs)[0] |
|
start_range_min = tf.maximum( |
|
0, ans_start_pos - (max_input_tokens - answer_len)) |
|
start_range_max = tf.minimum(ans_start_pos, |
|
inputs_len - max_input_tokens) + 1 |
|
|
|
start_pos = tf.random.stateless_uniform( |
|
[], |
|
minval=start_range_min, |
|
maxval=start_range_max, |
|
dtype=tf.int32, |
|
seed=seed) |
|
return inputs[start_pos:start_pos + max_input_tokens] |
|
|
|
result, pos_mask = answer_in_context(inputs, targets) |
|
|
|
if result: |
|
return slice_inputs(inputs, ans_len, pos_mask, seed=seed) |
|
else: |
|
return tf.constant([], dtype=inputs.dtype) |
|
|
|
if tf.greater(tf.shape(inputs)[0], max_input_tokens): |
|
inputs = truncate_inputs() |
|
return {'inputs': inputs, 'targets': features['targets']} |
|
|
|
dataset = my_fn(dataset) |
|
return dataset.filter(lambda x: tf.size(x['inputs']) > 0) |
|
|
|
|
|
@gin.configurable() |
|
def unsupervised(dataset, |
|
preprocessors=None, |
|
output_features=None, |
|
sequence_length=None): |
|
"""Configure this to point at unsupervised preprocessors. |
|
|
|
This function creates an extra level of indirection in case we want |
|
different unsupervised pretraining functions in the future which do not |
|
fit into the denoise() framework. |
|
|
|
This function should be used as a post-cache preprocessing function. |
|
|
|
Args: |
|
dataset: A tf.data.Dataset to process. |
|
preprocessors: a list of token-preprocessor functions. These functions |
|
should take unused kwargs if output_features or sequence_length is not |
|
used. |
|
output_features: dict(str, Feature), output features of the Task to be |
|
passed to the model. |
|
sequence_length: dict mapping feature key to int length for that feature. |
|
|
|
Returns: |
|
A preprocessed tf.data.Dataset. |
|
""" |
|
if preprocessors is None: |
|
logging.warning( |
|
'unsupervised preprocessor got preprocessors=None; no preprocessing ' |
|
'will be applied.' |
|
) |
|
return dataset |
|
|
|
kwargs = {} |
|
if output_features: |
|
kwargs['output_features'] = output_features |
|
if sequence_length: |
|
kwargs['sequence_length'] = sequence_length |
|
|
|
for p in preprocessors: |
|
dataset = p(dataset, **kwargs) |
|
return dataset |
|
|
|
|
|
|
|
|
|
@gin.configurable |
|
def split_tokens(dataset: tf.data.Dataset, |
|
min_tokens_per_segment: Optional[int] = None, |
|
max_tokens_per_segment: int = gin.REQUIRED, |
|
feature_key: str = 'targets', |
|
additional_feature_keys: Optional[Sequence[str]] = None, |
|
passthrough_feature_keys: Optional[Sequence[str]] = None, |
|
num_parallel_calls: int = AUTOTUNE, |
|
**unused_kwargs) -> tf.data.Dataset: |
|
"""Split examples into multiple examples each. |
|
|
|
The intended use case is to break up long examples for use in unsupervised |
|
transfer-learning. |
|
|
|
This function is generally preceded by select_random_chunk. |
|
|
|
If min_tokens_per_segment is provided, the segment length is chosen randomly |
|
per document from a log-uniform distribution. If min_tokens_per_segment is |
|
None, then the segment length is max_tokens_per_segment (except for a possibly |
|
shorter last segment in each document). |
|
|
|
Args: |
|
dataset: a tf.data.Dataset with dictionaries containing the key feature_key. |
|
min_tokens_per_segment: an optional integer |
|
max_tokens_per_segment: an integer, the maximum number of tokens in each |
|
segment. Only the final segment may be shorter. |
|
feature_key: a string, the feature to split |
|
additional_feature_keys: Additional features to split. The same chunk size |
|
will be used, so they should be the same size as feature_key. |
|
passthrough_feature_keys: Features to pass through without any splitting. |
|
num_parallel_calls: num_parallel_calls value to pass to map_over_dataset |
|
|
|
Returns: |
|
a dataset |
|
""" |
|
if passthrough_feature_keys: |
|
split_keys = set([feature_key] + (additional_feature_keys or [])) |
|
overlap_keys = split_keys & set(passthrough_feature_keys) |
|
if overlap_keys: |
|
raise ValueError( |
|
f'split keys {overlap_keys} also included in passthrough keys') |
|
|
|
@seqio.map_over_dataset(num_seeds=1, num_parallel_calls=num_parallel_calls) |
|
def _split_tokens(x, seed): |
|
"""Split one token sequence into multiple sequences.""" |
|
tokens = x[feature_key] |
|
n_tokens = tf.shape(tokens)[0] |
|
if min_tokens_per_segment is None: |
|
length = max_tokens_per_segment |
|
else: |
|
|
|
length = tf.cast( |
|
tf.exp( |
|
tf.random.stateless_uniform( |
|
[], |
|
minval=math.log(min_tokens_per_segment), |
|
maxval=math.log(max_tokens_per_segment), |
|
seed=seed |
|
) |
|
), |
|
tf.int32) |
|
|
|
|
|
|
|
num_segments = tf.cast( |
|
tf.math.ceil( |
|
tf.cast(n_tokens, tf.float32) / tf.cast(length, tf.float32)) |
|
, |
|
tf.int32) |
|
padding = num_segments * length - tf.shape(tokens)[0] |
|
feature_keys_to_split = [feature_key] |
|
orig_lengths = {} |
|
outputs = {} |
|
if additional_feature_keys is not None: |
|
feature_keys_to_split.extend(additional_feature_keys) |
|
for k in feature_keys_to_split: |
|
with tf.control_dependencies([ |
|
tf.assert_equal( |
|
tf.shape(tokens)[0], |
|
tf.shape(x[k])[0], |
|
message=(f'Additional feature {k} is not the same size as ' |
|
f'{feature_key} along axis 0 in split_tokens().') |
|
) |
|
]): |
|
shape = tf.shape(x[k])[1:] |
|
shape_list = x[k].shape[1:] |
|
padded = tf.pad( |
|
x[k], |
|
tf.concat([[[0, padding]], |
|
tf.zeros([len(shape_list), 2], dtype=tf.int32)], |
|
axis=0)) |
|
orig_lengths[k] = tf.concat( |
|
[tf.repeat(length, num_segments - 1), [length - padding]], axis=0) |
|
outputs[k] = tf.reshape( |
|
padded, tf.concat([[-1, length], shape], axis=0)) |
|
if passthrough_feature_keys: |
|
for k in passthrough_feature_keys: |
|
outputs[k] = tf.tile( |
|
tf.expand_dims(x[k], axis=0), |
|
tf.concat([[num_segments], tf.tile([1], [tf.rank(x[k])])], axis=0)) |
|
return outputs, orig_lengths |
|
|
|
def _strip_padding(inputs, orig_lengths): |
|
output = {} |
|
for k, v in inputs.items(): |
|
if passthrough_feature_keys and k in passthrough_feature_keys: |
|
output[k] = v |
|
else: |
|
output[k] = v[:orig_lengths[k]] |
|
return output |
|
|
|
|
|
dataset = dataset.filter(lambda x: tf.not_equal(tf.size(x[feature_key]), 0)) |
|
dataset = _split_tokens(dataset) |
|
dataset = dataset.unbatch() |
|
dataset = dataset.map(_strip_padding, num_parallel_calls=AUTOTUNE) |
|
return dataset |
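|
# Example: with min_tokens_per_segment=None and max_tokens_per_segment=4, a |
|
# 10-token 'targets' sequence becomes three examples of lengths 4, 4 and 2 |
|
# (the padding added for the reshape is stripped again by _strip_padding). |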
|
|
|
|
|
@gin.configurable |
|
def split_tokens_to_inputs_length(dataset, sequence_length, |
|
output_features, **kwargs): |
|
max_tokens = sequence_length['inputs'] |
|
if output_features['inputs'].add_eos: |
|
|
|
max_tokens -= 1 |
|
|
|
return split_tokens(dataset, max_tokens_per_segment=max_tokens, **kwargs) |
|
|
|
|
|
@gin.configurable |
|
def split_tokens_to_targets_length(dataset, sequence_length, |
|
output_features, **kwargs): |
|
max_tokens = sequence_length['targets'] |
|
if output_features['targets'].add_eos: |
|
|
|
max_tokens -= 1 |
|
|
|
return split_tokens(dataset, max_tokens_per_segment=max_tokens, **kwargs) |
|
|
|
|
|
@gin.configurable |
|
def split_tokens_to_random_length(dataset, sequence_length, |
|
output_features, **kwargs): |
|
max_tokens = sequence_length['inputs'] |
|
if output_features['inputs'].add_eos: |
|
|
|
max_tokens -= 1 |
|
|
|
return split_tokens(dataset, |
|
min_tokens_per_segment=8, |
|
max_tokens_per_segment=max_tokens, |
|
**kwargs) |
|
|
|
|
|
@gin.configurable |
|
def concatenate_and_split_to_fixed_length(dataset, |
|
sequence_length, |
|
output_features, |
|
feature_key='targets', |
|
**unused_kwargs): |
|
"""Concatenate tokens across examples, then split to fixed-size chunks. |
|
|
|
Chunk length is determined by sequence_length[feature_key]. |
|
|
|
Args: |
|
dataset: a tf.data.Dataset |
|
sequence_length: a dict of ints. |
|
output_features: a dict mapping feature name to t5.data.Feature. |
|
feature_key: a string |
|
Returns: |
|
a tf.data.Dataset |
|
""" |
|
dataset = dataset.map(lambda x: {feature_key: x[feature_key]}) |
|
max_tokens = sequence_length[feature_key] |
|
if output_features[feature_key].add_eos: |
|
|
|
max_tokens -= 1 |
|
return dataset.unbatch().batch(max_tokens) |
|
|
|
|
|
@gin.configurable |
|
def filter_by_string_length(dataset, |
|
feature_key='targets', |
|
min_length=1, |
|
max_length=1000000, |
|
**unused_kwargs): |
|
"""Filter examples by string length. |
|
|
|
Args: |
|
dataset: a tf.data.Dataset (not tokenized) |
|
feature_key: a string |
|
min_length: an integer |
|
max_length: an integer |
|
Returns: |
|
a tf.data.Dataset |
|
""" |
|
def my_fn(x): |
|
l = tf.strings.length(x[feature_key]) |
|
return tf.logical_and(tf.greater_equal(l, min_length), |
|
tf.less_equal(l, max_length)) |
|
return dataset.filter(my_fn) |
|
|
|
|
|
@gin.configurable |
|
def random_spans_helper(inputs_length=gin.REQUIRED, |
|
noise_density=gin.REQUIRED, |
|
mean_noise_span_length=gin.REQUIRED, |
|
extra_tokens_per_span_inputs=gin.REQUIRED, |
|
extra_tokens_per_span_targets=gin.REQUIRED, |
|
verbose=False): |
|
"""Training parameters to avoid padding with random_spans_noise_mask. |
|
|
|
When training a model with random_spans_noise_mask, we would like to set the |
|
other training hyperparameters in a way that avoids padding. This function |
|
helps us compute these hyperparameters. |
|
|
|
We assume that each noise span in the input is replaced by |
|
extra_tokens_per_span_inputs sentinel tokens, and each non-noise span in the |
|
targets is replaced by extra_tokens_per_span_targets sentinel tokens. |
|
|
|
This function tells us the required number of tokens in the raw example (for |
|
split_tokens()) as well as the length of the encoded targets. |
|
|
|
Note that this function assumes the inputs and targets will have EOS appended |
|
and includes that in the reported length. |
|
|
|
Args: |
|
inputs_length: an integer - desired length of the tokenized inputs sequence |
|
noise_density: a float |
|
mean_noise_span_length: a float |
|
extra_tokens_per_span_inputs: an integer |
|
extra_tokens_per_span_targets: an integer |
|
verbose: a bool indicating whether to log sequence lengths |
|
Returns: |
|
tokens_length: length of original text in tokens |
|
targets_length: an integer - length in tokens of encoded targets sequence |
|
""" |
|
def _tokens_length_to_inputs_length_targets_length(tokens_length): |
|
num_noise_tokens = int(round(tokens_length * noise_density)) |
|
num_nonnoise_tokens = tokens_length - num_noise_tokens |
|
num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length)) |
|
|
|
|
|
return ( |
|
num_nonnoise_tokens + |
|
num_noise_spans * extra_tokens_per_span_inputs + 1, |
|
num_noise_tokens + |
|
num_noise_spans * extra_tokens_per_span_targets + 1) |
|
|
|
tokens_length = inputs_length - 1 |
|
while (_tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] |
|
<= inputs_length): |
|
tokens_length += 1 |
|
inputs_length, targets_length = ( |
|
_tokens_length_to_inputs_length_targets_length(tokens_length)) |
|
|
|
|
|
if noise_density == 0.5 and targets_length > inputs_length: |
|
tokens_length -= 1 |
|
targets_length -= 1 |
|
if verbose: |
|
logging.info( |
|
'tokens_length=%s inputs_length=%s targets_length=%s ' |
|
'noise_density=%s mean_noise_span_length=%s ', |
|
tokens_length, inputs_length, targets_length, |
|
noise_density, mean_noise_span_length) |
|
return tokens_length, targets_length |
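|
# For example, inputs_length=512, noise_density=0.15, |
|
# mean_noise_span_length=3.0 and one sentinel per span on each side gives |
|
# tokens_length=568 and targets_length=114: 568 raw tokens yield 85 noise |
|
# tokens in 28 spans, i.e. 483 + 28 + 1 = 512 input tokens and |
|
# 85 + 28 + 1 = 114 target tokens (the final +1 in each case is the EOS). |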
|
|
|
|
|
@gin.configurable |
|
def random_spans_tokens_length(): |
|
"""Helper for gin-configuring split_tokens with random_spans_noise_mask.""" |
|
return random_spans_helper()[0] |
|
|
|
|
|
@gin.configurable |
|
def random_spans_targets_length(): |
|
"""Helper for gin-configuring the targets sequence length.""" |
|
return random_spans_helper()[1] |
|
|
|
|
|
|
|
|
|
|
|
@gin.configurable() |
|
def denoise(dataset, |
|
output_features, |
|
noise_density=gin.REQUIRED, |
|
noise_mask_fn=gin.REQUIRED, |
|
inputs_fn=gin.REQUIRED, |
|
targets_fn=None, |
|
passthrough_feature_keys: Optional[Sequence[str]] = None, |
|
input_feature_key='inputs', |
|
**unused_kwargs): |
|
"""Gin-configurable token preprocessor for self-supervised denoising tasks. |
|
|
|
This function takes a dataset containing "targets" sequences, |
|
and turns each sequence into a dictionary containing: |
|
{ |
|
"inputs": noisy version of the original sequence |
|
"targets": the full original sequence or missing parts of original sequence |
|
} |
|
|
|
In particular, for each sequence, we choose a boolean noise_mask identifying |
|
which tokens in the sequence to corrupt, as defined by the given |
|
noise_mask_fn. |
|
|
|
Given the sequence and the noise mask, we generate the inputs and targets |
|
using the given inputs_fn and targets_fn respectively. |
|
|
|
The self-supervised tasks vary along these axes: |
|
- noise_density: What fraction of the tokens to select as noise |
|
- noise_mask_fn: What pattern should the noise mask follow |
|
(iid, regular segments, etc.) |
|
- inputs_fn: How to apply the noise |
|
(drop noise tokens, replace with sentinels, etc.) |
|
- targets_fn: How to represent the output |
|
(full sequence, only non-noise tokens, etc.) |
|
|
|
Note: Some functionality has been deleted, which we may or may not want to |
|
restore at a later date. The code for this functionality can be found in |
|
the deleted code for this CL. In particular: |
|
- mixture of masking and random replacement |
|
- task labels prepended to the inputs |
|
|
|
Args: |
|
dataset: A tf.data.Dataset to process. |
|
output_features: a dict mapping feature name to t5.data.Feature. |
|
noise_density: a float |
|
noise_mask_fn: a function from (length, noise_density) -> boolean mask |
|
inputs_fn: a function from (tokens, noise_mask, vocabulary) -> tokens |
|
targets_fn: a function from (tokens, noise_mask, vocabulary) -> tokens |
|
passthrough_feature_keys: names of additional features to include in output |
|
input_feature_key: name of feature to use as inputs |
|
|
|
Returns: |
|
A preprocessed tf.data.Dataset. |
|
""" |
|
if passthrough_feature_keys and (input_feature_key in passthrough_feature_keys |
|
or 'targets' in passthrough_feature_keys): |
|
raise ValueError( |
|
f"passthrough keys cannot contain '{input_feature_key}' or 'targets'") |
|
|
|
@seqio.map_over_dataset(num_seeds=6) |
|
def my_fn(features, seeds): |
|
"""Map function.""" |
|
tokens = features['targets'] |
|
vocabulary = output_features['targets'].vocabulary |
|
if (input_feature_key in output_features and |
|
vocabulary != output_features[input_feature_key].vocabulary): |
|
raise ValueError( |
|
'denoise creates inputs based on tokenized targets but was applied ' |
|
'to a task that uses different vocabularies for inputs and targets.') |
|
noise_mask = noise_mask_fn(tf.size(tokens), noise_density, seeds=seeds[:2]) |
|
inputs = inputs_fn(tokens, noise_mask, vocabulary, seeds=seeds[2:4]) |
|
if targets_fn: |
|
targets = targets_fn(tokens, noise_mask, vocabulary, seeds=seeds[4:6]) |
|
else: |
|
targets = tokens |
|
return { |
|
input_feature_key: inputs, |
|
'targets': targets, |
|
**{ |
|
k: features[k] |
|
for k in features |
|
if passthrough_feature_keys and k in passthrough_feature_keys |
|
} |
|
} |
|
|
|
return my_fn(dataset) |
|
|
|
|
|
@gin.configurable() |
|
def iid_noise_mask(length, noise_density, seeds): |
|
"""Independent and identically distributed token noise. |
|
|
|
Args: |
|
length: an int32 scalar. |
|
noise_density: a float - approximate density of output mask. |
|
seeds: an int32 Tensor, shaped (1, 2), the random seed. |
|
|
|
Returns: |
|
a boolean tensor with shape [length]. |
|
""" |
|
return tf.random.stateless_uniform([length], seed=seeds[0]) < noise_density |
|
|
|
|
|
@gin.configurable() |
|
def regular_noise_mask(length, |
|
noise_density, |
|
seeds, |
|
min_span_length=1, |
|
max_span_length=5): |
|
"""Noise mask consisting of equally spaced spans of equal length. |
|
|
|
The span length and the offset are chosen randomly per-example. |
|
The beginning and end of the sequence may be part of shorter spans of noise. |
|
For example, if noise_density=0.25 and a span length of 2 is chosen, |
|
then the output might be: |
|
|
|
[T F F F F F F T T F F F F F F T T F F F F F F T T F F] |
|
|
|
Args: |
|
length: an int32 scalar. |
|
noise_density: a float - approximate density of output mask. |
|
seeds: an int32 Tensor, shaped (2, 2), the random seeds. |
|
min_span_length: an integer. |
|
max_span_length: an integer. |
|
|
|
Returns: |
|
a boolean tensor with shape [length]. |
|
""" |
|
span_length = tf.random.stateless_uniform( |
|
[], |
|
minval=min_span_length, |
|
maxval=max_span_length + 1, |
|
dtype=tf.int32, |
|
seed=seeds[0]) |
|
period = tf.cast( |
|
tf.round(tf.cast(span_length, tf.float32) / noise_density), tf.int32) |
|
offset = tf.random.stateless_uniform( |
|
[], |
|
maxval=period, |
|
dtype=tf.int32, |
|
seed=seeds[1]) |
|
return (tf.range(length, dtype=tf.int32) + offset) % period < span_length |
|
|
|
|
|
@gin.configurable() |
|
def random_spans_noise_mask(length, |
|
noise_density, |
|
seeds, |
|
mean_noise_span_length=3.0): |
|
"""Noise mask consisting of random spans of noise tokens. |
|
|
|
The number of noise tokens and the number of noise spans and non-noise spans |
|
are determined deterministically as follows: |
|
|
|
num_noise_tokens = round(length * noise_density) |
|
num_nonnoise_spans = num_noise_spans = round( |
|
num_noise_tokens / mean_noise_span_length) |
|
|
|
Spans alternate between non-noise and noise, beginning with non-noise. |
|
Subject to the above restrictions, all masks are equally likely. |
|
|
|
Args: |
|
length: an int32 scalar (length of the incoming token sequence) |
|
noise_density: a float - approximate density of output mask |
|
seeds: an int32 Tensor, shaped (2, 2) |
|
mean_noise_span_length: a number |
|
|
|
Returns: |
|
a boolean tensor with shape [length] |
|
""" |
|
|
|
orig_length = length |
|
|
|
length = tf.maximum(length, 2) |
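|
# The clamp above treats very short sequences as having length >= 2 so the |
|
# mask computation below never degenerates; the result is trimmed back to |
|
# orig_length before returning. |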
|
def to_int(x): |
|
return tf.cast(x, tf.int32) |
|
def to_float(x): |
|
return tf.cast(x, tf.float32) |
|
num_noise_tokens = to_int(tf.round(to_float(length) * noise_density)) |
|
|
|
num_noise_tokens = tf.minimum(tf.maximum(num_noise_tokens, 1), length - 1) |
|
num_noise_spans = to_int( |
|
tf.round(to_float(num_noise_tokens) / mean_noise_span_length)) |
|
|
|
num_noise_spans = tf.maximum(num_noise_spans, 1) |
|
num_nonnoise_tokens = length - num_noise_tokens |
|
|
|
def _random_segmentation(num_items, num_segments, seed): |
|
"""Partition a sequence of items randomly into non-empty segments. |
|
|
|
Args: |
|
num_items: an integer scalar > 0 |
|
num_segments: an integer scalar in [1, num_items] |
|
seed: an integer seed |
|
Returns: |
|
a Tensor with shape [num_segments] containing positive integers that add |
|
up to num_items |
|
""" |
|
first_in_segment = tf.pad( |
|
seqio.stateless_shuffle( |
|
to_int(tf.range(num_items - 1) < num_segments - 1), |
|
seed), |
|
[[1, 0]]) |
|
segment_id = tf.cumsum(first_in_segment) |
|
segment_length = tf.math.segment_sum(tf.ones_like(segment_id), segment_id) |
|
return segment_length |
|
noise_span_lengths = _random_segmentation( |
|
num_noise_tokens, num_noise_spans, seeds[0]) |
|
nonnoise_span_lengths = _random_segmentation( |
|
num_nonnoise_tokens, num_noise_spans, seeds[1]) |
|
interleaved_span_lengths = tf.reshape( |
|
tf.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), |
|
[num_noise_spans * 2]) |
|
span_starts = tf.cumsum(interleaved_span_lengths)[:-1] |
|
span_start_indicator = tf.math.unsorted_segment_sum( |
|
tf.ones_like(span_starts), span_starts, length) |
|
span_num = tf.cumsum(span_start_indicator) |
|
is_noise = tf.equal(span_num % 2, 1) |
|
return is_noise[:orig_length] |
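|
# Worked example: length=40 and noise_density=0.15 give 6 noise tokens in |
|
# 2 spans; the 34 non-noise tokens are randomly split into 2 non-empty runs, |
|
# so the mask alternates non-noise/noise/non-noise/noise starting with |
|
# non-noise. |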
|
|
|
|
|
@gin.configurable() |
|
def random_prefix_noise_mask(length, noise_density, seeds): |
|
"""First part of the sequence is noise (for prefix_lm). |
|
|
|
The length of the prefix is chosen uniformly at random from [1, length). |
|
noise_density must be 0.5. |
|
TODO(noam): figure out some distribution to use if noise_density != 0.5. |
|
|
|
Args: |
|
length: an int32 scalar. |
|
noise_density: a float - must equal 0.5. |
|
seeds: an int32 Tensor, shaped (1, 2), the random seed. |
|
|
|
Returns: |
|
a boolean tensor with shape [length]. |
|
""" |
|
if noise_density != 0.5: |
|
raise NotImplementedError( |
|
'noise density must equal 0.5 for random_prefix_noise_mask') |
|
max_input_tokens = length - 1 |
|
min_input_tokens = tf.minimum(max_input_tokens, 1) |
|
num_input_tokens = tf.random.stateless_uniform( |
|
[], |
|
minval=min_input_tokens, |
|
maxval=max_input_tokens + 1, |
|
dtype=tf.int32, |
|
seed=seeds[0]) |
|
return tf.range(length, dtype=tf.int32) < num_input_tokens |
|
|
|
|
|
@gin.configurable() |
|
def sentinel_id(vocabulary, return_value=None): |
|
"""Token ID to use as a sentinel. |
|
|
|
By default, we use the last token in the vocabulary. |
|
|
|
Args: |
|
vocabulary: a t5.data.vocabularies.Vocabulary |
|
return_value: an optional integer |
|
Returns: |
|
an integer |
|
""" |
|
if return_value is not None: |
|
return return_value |
|
return vocabulary.vocab_size - 1 |
|
|
|
|
|
@gin.configurable() |
|
def noise_token_to_sentinel(tokens, noise_mask, vocabulary, seeds): |
|
"""Replace each noise token with the given sentinel. |
|
|
|
Args: |
|
tokens: a 1d integer Tensor |
|
noise_mask: a boolean Tensor with the same shape as tokens |
|
vocabulary: a vocabulary.Vocabulary |
|
seeds: an unused int32 Tensor |
|
|
|
Returns: |
|
a Tensor with the same shape and dtype as tokens |
|
""" |
|
del seeds |
|
return tf.where(noise_mask, |
|
tf.cast(sentinel_id(vocabulary), tokens.dtype), |
|
tokens) |
|
|
|
|
|
@gin.configurable() |
|
def noise_span_to_sentinel(tokens, noise_mask, vocabulary, seeds): |
|
"""Replace each run of consecutive noise tokens with a single sentinel. |
|
|
|
Args: |
|
tokens: a 1d integer Tensor |
|
noise_mask: a boolean Tensor with the same shape as tokens |
|
vocabulary: a vocabulary.Vocabulary |
|
seeds: an unused int32 Tensor |
|
Returns: |
|
a Tensor with the same shape and dtype as tokens |
|
""" |
|
del seeds |
|
tokens = tf.where(noise_mask, |
|
tf.cast(sentinel_id(vocabulary), tokens.dtype), |
|
tokens) |
|
prev_token_is_noise = tf.pad(noise_mask[:-1], [[1, 0]]) |
|
subsequent_noise_tokens = tf.logical_and(noise_mask, prev_token_is_noise) |
|
return tf.boolean_mask(tokens, tf.logical_not(subsequent_noise_tokens)) |
|
|
|
|
|
@gin.configurable() |
|
def nonnoise_span_to_sentinel(tokens, noise_mask, vocabulary, seeds): |
|
return noise_span_to_sentinel( |
|
tokens, tf.logical_not(noise_mask), vocabulary, seeds) |
|
|
|
|
|
@gin.configurable() |
|
def noise_span_to_unique_sentinel(tokens, noise_mask, vocabulary, seeds): |
|
"""Replace each run of consecutive noise tokens with a different sentinel. |
|
|
|
The idea here is to be able to align the dropped spans in the inputs |
|
with the markers in the targets. |
|
|
|
We want to generate training examples like |
|
"We hold X to be Y that" -> "X these truths Y self evident Z" |
|
|
|
Sentinels are assigned in decreasing order within the sequence, starting at |
|
vocabulary.vocab_size - 1. That is, we appropriate the last tokens in the |
|
vocabulary for additional use as sentinels. |
|
|
|
TODO(noam): we may want to try enlarging the vocabulary and leaving room |
|
for the sentinels instead. However, this requires enlarging the embedding |
|
tables in the model, so that is a bigger change. |
|
|
|
Args: |
|
tokens: a 1d integer Tensor |
|
noise_mask: a boolean Tensor with the same shape as tokens |
|
vocabulary: a vocabulary.Vocabulary |
|
seeds: an unused int32 Tensor |
|
Returns: |
|
a Tensor with the same shape and dtype as tokens |
|
""" |
|
del seeds |
|
|
|
prev_token_is_noise = tf.pad(noise_mask[:-1], [[1, 0]]) |
|
|
|
first_noise_tokens = tf.logical_and( |
|
noise_mask, tf.logical_not(prev_token_is_noise)) |
|
subsequent_noise_tokens = tf.logical_and(noise_mask, prev_token_is_noise) |
|
|
|
sentinel = sentinel_id(vocabulary) + 1 - tf.cumsum( |
|
tf.cast(first_noise_tokens, tokens.dtype)) |
|
|
|
tokens = tf.where(first_noise_tokens, sentinel, tokens) |
|
return tf.boolean_mask(tokens, tf.logical_not(subsequent_noise_tokens)) |
|
|
|
|
|
@gin.configurable() |
|
def nonnoise_span_to_unique_sentinel(tokens, noise_mask, vocabulary, seeds): |
|
return noise_span_to_unique_sentinel( |
|
tokens, tf.logical_not(noise_mask), vocabulary, seeds) |
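|
# Example of how the two *_to_unique_sentinel functions align: for tokens |
|
# [a, b, c, d, e, f, g] with noise at b, c and f, and S1 = vocab_size - 1, |
|
# S2 = S1 - 1, S3 = S1 - 2: |
|
#   inputs  (noise_span_to_unique_sentinel)    -> [a, S1, d, e, S2, g] |
|
#   targets (nonnoise_span_to_unique_sentinel) -> [S1, b, c, S2, f, S3] |
|
# so each sentinel in the inputs is followed in the targets by the exact |
|
# span it replaced. |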
|
|
|
|
|
@gin.configurable() |
|
def drop_noise_tokens(tokens, noise_mask, vocabulary, seeds): |
|
"""Drop noise tokens without inserting a sentinel. |
|
|
|
Args: |
|
tokens: a 1d integer Tensor |
|
noise_mask: a boolean Tensor with the same shape as tokens |
|
vocabulary: an unused vocabulary.Vocabulary |
|
seeds: an unused int32 Tensor |
|
|
|
Returns: |
|
a Tensor with the same shape and dtype as tokens |
|
""" |
|
del vocabulary |
|
del seeds |
|
return tf.boolean_mask(tokens, tf.logical_not(noise_mask)) |
|
|
|
|
|
@gin.configurable() |
|
def drop_nonnoise_tokens(tokens, noise_mask, vocabulary, seeds): |
|
"""Drop non-noise tokens without inserting a sentinel. |
|
|
|
Args: |
|
tokens: a 1d integer Tensor |
|
noise_mask: a boolean Tensor with the same shape as tokens |
|
vocabulary: an unused vocabulary.Vocabulary |
|
seeds: an unused int32 Tensor |
|
Returns: |
|
a Tensor with the same shape and dtype as tokens |
|
""" |
|
del vocabulary |
|
del seeds |
|
return tf.boolean_mask(tokens, noise_mask) |
|
|
|
|
|
@gin.configurable() |
|
def permute_noise_tokens(tokens, noise_mask, vocabulary, seeds): |
|
"""Permute the noise tokens, keeping the non-noise tokens where they are. |
|
|
|
Args: |
|
tokens: a 1d integer Tensor |
|
noise_mask: a boolean Tensor with the same shape as tokens |
|
vocabulary: an unused vocabulary.Vocabulary |
|
seeds: an int32 Tensor, sized (1, 2) |
|
Returns: |
|
a Tensor with the same shape and dtype as tokens |
|
""" |
|
del vocabulary |
|
|
|
masked_only = tf.boolean_mask(tokens, noise_mask) |
|
permuted = seqio.stateless_shuffle(masked_only, seeds[0]) |
|
|
|
permuted = tf.pad(permuted, [[0, 1]]) |
|
indices = tf.cumsum(tf.cast(noise_mask, tf.int32), exclusive=True) |
|
return tf.where(noise_mask, |
|
tf.gather(permuted, indices), |
|
tokens) |
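|
# Example: tokens=[5, 6, 7, 8] with noise_mask=[F, T, F, T] keeps 5 and 7 in |
|
# place and writes a random permutation of (6, 8) into positions 1 and 3. |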
|
|
|
|
|
@gin.configurable() |
|
def noise_token_to_gathered_token(tokens, noise_mask, vocabulary, seeds): |
|
"""Replace each noise token with a random token from the sequence. |
|
|
|
Args: |
|
tokens: a 1d integer Tensor |
|
noise_mask: a boolean Tensor with the same shape as tokens |
|
vocabulary: an unused vocabulary.Vocabulary |
|
seeds: an int32 Tensor, sized (1, 2) |
|
Returns: |
|
a Tensor with the same shape and dtype as tokens |
|
""" |
|
del vocabulary |
|
|
|
indices = tf.random.stateless_uniform( |
|
shape=tf.shape(tokens), |
|
maxval=tf.size(tokens), |
|
dtype=tf.int32, |
|
seed=seeds[0]) |
|
return tf.where(noise_mask, |
|
tf.gather(tokens, indices), |
|
tokens) |
|
|
|
|
|
@gin.configurable() |
|
def noise_token_to_random_token( |
|
tokens, |
|
noise_mask, |
|
vocabulary, |
|
seeds, |
|
num_reserved_tokens=3): |
|
"""Replace each noise token with a random token from the vocabulary. |
|
|
|
|
|
|
|
Args: |
|
tokens: a 1d integer Tensor |
|
noise_mask: a boolean Tensor with the same shape as tokens |
|
vocabulary: a vocabulary.Vocabulary |
|
seeds: an int32 Tensor, shaped (1, 2) |
|
num_reserved_tokens: an integer |
|
Returns: |
|
a Tensor with the same shape and dtype as tokens |
|
""" |
|
return tf.where(noise_mask, |
|
tf.random.stateless_uniform( |
|
tf.shape(tokens), |
|
minval=num_reserved_tokens, |
|
maxval=vocabulary.vocab_size, |
|
dtype=tokens.dtype, |
|
seed=seeds[0]), |
|
tokens) |
|
|
|
|
|
@gin.configurable() |
|
def noise_token_to_random_token_or_sentinel( |
|
tokens, |
|
noise_mask, |
|
vocabulary, |
|
seeds, |
|
random_prob=0.1): |
|
"""Replace each noise token with a random token or a sentinel. |
|
|
|
For each masked token, with probability random_prob, we replace it by a |
|
random token from the vocabulary. Otherwise, we replace it with a sentinel. |
|
|
|
Args: |
|
tokens: a 1d integer Tensor |
|
noise_mask: a boolean Tensor with the same shape as tokens |
|
vocabulary: a vocabulary.Vocabulary |
|
seeds: an int32 Tensor, shaped (2, 2). |
|
random_prob: a float |
|
Returns: |
|
a Tensor with the same shape and dtype as tokens |
|
""" |
|
use_random = ( |
|
tf.random.stateless_uniform(tf.shape(tokens), seed=seeds[0]) < |
|
random_prob) |
|
return tf.where( |
|
use_random, |
|
noise_token_to_random_token( |
|
tokens, noise_mask, vocabulary, seeds=seeds[1:]), |
|
noise_token_to_sentinel( |
|
tokens, noise_mask, vocabulary, seeds=())) |
|
|
|
|
|
|
|
|
|
|
|
def trim_and_pad_dataset(dataset, sequence_length): |
|
"""A wrapper to use `seqio.utils.trim_and_pad_dataset` as a preprocessor.""" |
|
return seqio.utils.trim_and_pad_dataset( |
|
dataset, feature_lengths=sequence_length) |
|
|
|
|
|
def targets_for_prefix_lm_objective(dataset, sequence_length, output_features): |
|
"""Prepares targets to be used for prefix LM objective.""" |
|
dataset = select_random_chunk( |
|
dataset, output_features, max_length=65536, feature_key='targets') |
|
dataset = seqio.preprocessors.append_eos(dataset, output_features) |
|
dataset = reduce_concat_tokens(dataset, batch_size=128) |
|
dataset = split_tokens( |
|
dataset, max_tokens_per_segment=sequence_length['targets']) |
|
dataset = trim_and_pad_dataset(dataset, sequence_length) |
|
return dataset |
|
|
|
|
|
def pack_prefix_lm_encoder_decoder(ds, sequence_length, pad_id=0): |
|
"""Pack two examples into one with the prefix LM objective.""" |
|
packed_length = next(iter(sequence_length.values())) |
|
assert packed_length % 2 == 0 |
|
assert all(l == packed_length for l in sequence_length.values()) |
|
|
|
@seqio.utils.map_over_dataset(num_seeds=1) |
|
def pack_examples(example_pair, seed): |
|
split_point = tf.random.stateless_uniform((), |
|
minval=1, |
|
maxval=packed_length, |
|
seed=seed, |
|
dtype=tf.int32) |
|
inputs = tf.concat([ |
|
example_pair['targets'][0][:split_point], |
|
example_pair['targets'][1][:packed_length - split_point] |
|
], |
|
axis=0) |
|
inputs = tf.reshape(inputs, (packed_length,)) |
|
targets = tf.concat([ |
|
example_pair['targets'][0][split_point:], |
|
example_pair['targets'][1][packed_length - split_point:] |
|
], |
|
axis=0) |
|
targets = tf.reshape(targets, (packed_length,)) |
|
|
|
encoder_segment_ids = tf.cast( |
|
tf.range(packed_length) >= split_point, tf.int32) + 1 |
|
decoder_segment_ids = tf.cast( |
|
tf.range(packed_length) >= (packed_length - split_point), tf.int32) + 1 |
|
|
|
decoder_input_tokens = seqio.utils.make_autoregressive_inputs( |
|
targets, sequence_id=decoder_segment_ids) |
|
|
|
encoder_positions = tf.concat( |
|
[tf.range(split_point), |
|
tf.range(packed_length - split_point)], axis=0) |
|
encoder_positions = tf.reshape(encoder_positions, (packed_length,)) |
|
decoder_positions = tf.concat( |
|
[tf.range(packed_length - split_point), |
|
tf.range(split_point)], axis=0) |
|
decoder_positions = tf.reshape(decoder_positions, (packed_length,)) |
|
decoder_loss_weights = tf.cast( |
|
tf.not_equal(targets, pad_id), dtype=tf.int32) |
|
return { |
|
'encoder_input_tokens': inputs, |
|
'decoder_target_tokens': targets, |
|
'decoder_input_tokens': decoder_input_tokens, |
|
'encoder_segment_ids': encoder_segment_ids, |
|
'encoder_positions': encoder_positions, |
|
'decoder_segment_ids': decoder_segment_ids, |
|
'decoder_positions': decoder_positions, |
|
'decoder_loss_weights': decoder_loss_weights, |
|
} |
|
|
|
|
|
return pack_examples(ds.batch(2)) |
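|
# Sketch of the packed layout: two length-N 'targets' examples (a, b) and a |
|
# random split_point s yield encoder inputs [a[:s], b[:N - s]] and decoder |
|
# targets [a[s:], b[N - s:]], each of length N, with segment ids 1 and 2 |
|
# marking which original example each position came from. |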
|
|
|
|
|
def pack_prefix_lm_decoder_only(ds, |
|
sequence_length, |
|
loss_on_targets_only=True, |
|
pad_id=0): |
|
"""Randomly split the tokens for the prefix LM objective.""" |
|
packed_length = next(iter(sequence_length.values())) |
|
assert packed_length % 2 == 0 |
|
assert all(l == packed_length for l in sequence_length.values()) |
|
|
|
@seqio.utils.map_over_dataset(num_seeds=1) |
|
def pack_examples(example, seed): |
|
split_point = tf.random.stateless_uniform((), |
|
minval=1, |
|
maxval=packed_length, |
|
seed=seed, |
|
dtype=tf.int32) |
|
decoder_target_tokens = example['targets'] |
|
decoder_input_tokens = seqio.utils.make_autoregressive_inputs( |
|
decoder_target_tokens) |
|
|
|
if loss_on_targets_only: |
|
decoder_loss_weights = tf.cast( |
|
tf.range(packed_length) >= split_point, tf.int32) |
|
else: |
|
decoder_loss_weights = tf.ones((packed_length,), dtype=tf.int32) |
|
|
|
padding_mask = tf.cast( |
|
tf.not_equal(decoder_target_tokens, pad_id), dtype=tf.int32) |
|
decoder_loss_weights *= padding_mask |
|
|
|
decoder_causal_attention = tf.cast( |
|
tf.range(packed_length) <= split_point, tf.int32) |
|
|
|
return { |
|
'decoder_target_tokens': decoder_target_tokens, |
|
'decoder_input_tokens': decoder_input_tokens, |
|
'decoder_loss_weights': decoder_loss_weights, |
|
'decoder_causal_attention': decoder_causal_attention, |
|
} |
|
|
|
return pack_examples(ds) |
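|
# Sketch of the resulting features: positions 0..split_point (inclusive) get |
|
# decoder_causal_attention=1 (the fully-visible prefix); with |
|
# loss_on_targets_only=True the loss weights are 1 only from split_point |
|
# onward, and they are always zeroed on padding. |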
|
|