|
import functools |
|
import seqio |
|
import tensorflow as tf |
|
import t5.data |
|
from datasets import load_dataset |
|
from t5.data import postprocessors |
|
from t5.data import preprocessors |
|
from t5.evaluation import metrics |
|
from seqio import FunctionDataSource, utils |
|
|
|
# Short alias for the seqio registry that the task below is added to.
TaskRegistry = seqio.TaskRegistry

# One SentencePiece vocabulary shared by both features; no extra ids are
# requested on top of what the model file provides (extra_ids=0).
vocabulary = seqio.SentencePieceVocabulary(
    'gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model',
    extra_ids=0,
)

# Feature specs keyed by feature name. "inputs" is marked optional
# (required=False); "targets" uses the Feature default for `required`.
DEFAULT_OUTPUT_FEATURES = dict(
    inputs=seqio.Feature(vocabulary=vocabulary, add_eos=True, required=False),
    targets=seqio.Feature(vocabulary=vocabulary, add_eos=True),
)
|
|
|
|
|
def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_params=None):
    """Yield the ``column`` value of every record in ``split``, repeating forever.

    Args:
        split: Split to read; converted with ``str()`` and used as the key
            into the object returned by ``load_dataset``.
        shuffle: When true, the loaded dataset is shuffled before iteration.
        seed: Optional shuffle seed. Any value other than ``None`` is
            forwarded, including ``0``.
        column: Name of the field to yield from each record.
        dataset_params: Keyword arguments forwarded verbatim to
            ``load_dataset``. Required in practice: the default ``None``
            cannot be unpacked and raises ``TypeError``.

    Yields:
        ``item[column]`` for each record of the split. Once the split is
        exhausted it is iterated again from the start; the generator only
        terminates if a full pass produces no records at all.
    """
    dataset = load_dataset(**dataset_params)
    if shuffle:
        # Compare against None explicitly: ``seed=0`` is a legitimate seed and
        # must not fall through to the unseeded branch.
        if seed is not None:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()

    records = dataset[str(split)]
    while True:
        produced = False
        for item in records:
            produced = True
            yield item[column]
        if not produced:
            # An empty split would otherwise spin forever without ever
            # yielding, hanging whoever consumes this generator.
            return
|
|
|
|
|
def dataset_fn(split, shuffle_files, seed=None, dataset_params=None):
    """Wrap ``gen_dataset`` in a ``tf.data.Dataset`` of scalar strings.

    Args:
        split: Split name, passed through to ``gen_dataset``.
        shuffle_files: Passed to ``gen_dataset`` as its ``shuffle`` argument.
        seed: Optional shuffle seed, passed through to ``gen_dataset``.
        dataset_params: Keyword arguments that ``gen_dataset`` forwards to
            ``load_dataset``.

    Returns:
        A ``tf.data.Dataset`` built with ``from_generator`` whose elements
        match a scalar ``tf.string`` ``TensorSpec``.
    """
    # Name the spec after the dataset that is actually being loaded instead of
    # reading a module-level global, so the function stays correct when it is
    # reused with different ``dataset_params``.
    spec_name = dataset_params.get("path") if dataset_params else None
    return tf.data.Dataset.from_generator(
        functools.partial(
            gen_dataset, split, shuffle_files, seed, dataset_params=dataset_params
        ),
        output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=spec_name),
    )
|
|
|
|
|
|
|
@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Return a copy of ``key_map`` in which ``target_key`` maps to ``x``.

    ``key_map`` itself is left unmodified; if it already contains
    ``target_key``, that entry's value is replaced by ``x`` in the result.
    """
    overlay = {target_key: x}
    return {**key_map, **overlay}
|
|
|
|
|
|
|
|
|
# Dataset identifier and the keyword arguments handed to `load_dataset`
# (via dataset_fn -> gen_dataset).
dataset_name = 'NbAiLab/scandinavian'
dataset_params = dict(path=dataset_name, use_auth_token=True, streaming=True)
# Forwarded as `num_input_examples` below; no per-split counts are supplied.
dataset_shapes = None

TaskRegistry.add(
    "ncc_scandinavian_span_corruption_stream",
    source=FunctionDataSource(
        dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
        splits=("train", "validation"),
        caching_permitted=False,
        num_input_examples=dataset_shapes,
    ),
    preprocessors=[
        # Turn each raw string into {"inputs": None, "targets": <string>}.
        functools.partial(
            target_to_key,
            key_map=dict(inputs=None, targets=None),
            target_key="targets",
        ),
        seqio.preprocessors.tokenize,
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,
    ],
    # Only the "targets" feature spec is declared for this task.
    output_features=dict(targets=DEFAULT_OUTPUT_FEATURES["targets"]),
    metric_fns=[],
)
|
|
|
|
|
|
|
|