|
|
|
|
|
import functools |
|
|
|
import seqio |
|
import tensorflow as tf |
|
import t5.data |
|
from datasets import load_dataset, load_from_disk |
|
from t5.data import postprocessors |
|
from t5.data import preprocessors |
|
from t5.evaluation import metrics |
|
from seqio import FunctionDataSource, utils |
|
|
|
# Short alias so the registration at the bottom of the module can call
# `TaskRegistry.add(...)`.
TaskRegistry = seqio.TaskRegistry

# SentencePiece vocabulary loaded from the relative path 'spiece.model'.
# extra_ids=0: no extra (sentinel) ids are appended on top of the model.
# NOTE(review): span_corruption relies on sentinel ids; confirm that
# spiece.model already reserves them at the top of its vocabulary.
vocabulary = seqio.SentencePieceVocabulary('spiece.model', extra_ids=0)

# Feature specs shared by tasks in this module. Both features use the same
# vocabulary and append EOS; "inputs" is marked optional (required=False).
DEFAULT_OUTPUT_FEATURES = dict(
    inputs=seqio.Feature(vocabulary=vocabulary, add_eos=True, required=False),
    targets=seqio.Feature(vocabulary=vocabulary, add_eos=True),
)
|
|
|
|
|
def gen_dataset(split, shuffle=False, seed=None, column="text", dataset=None):
    """Yield ``example[column]`` for every example of ``dataset[split]``, endlessly.

    Args:
      split: Split to read; converted with ``str()`` before the lookup.
      shuffle: If true, shuffle the selected split once before iterating.
      seed: Seed forwarded to ``shuffle``. ``None`` means unseeded; ``0`` is
        honoured as a real seed.
      column: Key read from each example.
      dataset: Mapping from split name to an iterable of example mappings
        (e.g. a Hugging Face ``DatasetDict``); the split object must provide
        ``shuffle(seed=...)`` when ``shuffle`` is true. Required.

    Yields:
      ``example[column]`` for each example, cycling over the split forever.
      If the split is empty, the generator simply ends instead of spinning.
      NOTE(review): the validation split is also repeated forever; confirm
      the evaluator bounds the number of examples it takes.

    Raises:
      TypeError: if ``dataset`` is None (raised on first iteration).
    """
    if dataset is None:
        raise TypeError("gen_dataset() requires a `dataset` mapping of splits")

    examples = dataset[str(split)]
    if shuffle:
        # Shuffle only the split we read, not every split of the dataset.
        # Passing `seed` unconditionally means seed=0 stays deterministic;
        # seed=None behaves like an unseeded shuffle.
        examples = examples.shuffle(seed=seed)

    while True:
        produced = False
        for example in examples:
            produced = True
            yield example[column]
        if not produced:
            # An empty split would otherwise make `while True` loop forever
            # without ever yielding.
            return
|
|
|
|
|
def dataset_fn(split, shuffle_files, seed=None, dataset=None, column="text"):
    """Build a ``tf.data.Dataset`` of scalar strings for one split.

    The leading parameters (``split``, ``shuffle_files``, ``seed``) match what
    ``seqio.FunctionDataSource`` passes; ``dataset`` is bound with
    ``functools.partial`` when the task is registered.

    Args:
      split: Split name forwarded to ``gen_dataset``.
      shuffle_files: Forwarded as ``gen_dataset``'s ``shuffle`` flag.
      seed: Optional shuffle seed forwarded to ``gen_dataset``.
      dataset: Split-name -> examples mapping forwarded to ``gen_dataset``.
      column: Example field to emit (default "text", as before); also used as
        the TensorSpec name.

    Returns:
      A ``tf.data.Dataset`` whose elements are scalar ``tf.string`` tensors.
    """
    generator = functools.partial(
        gen_dataset,
        split,
        shuffle=shuffle_files,
        seed=seed,
        column=column,
        dataset=dataset,
    )
    return tf.data.Dataset.from_generator(
        generator,
        # Name the spec after the emitted column rather than a module-level
        # path global defined further down the file, so this function has no
        # hidden dependency on module state.
        output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=column),
    )
|
|
|
|
|
@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Return a copy of ``key_map`` with ``target_key`` set to ``x``.

    ``key_map`` itself is never mutated; a fresh dict is built per call.
    The seqio ``utils.map_over_dataset`` decorator lifts this per-example
    function so it can be applied to a whole dataset.

    NOTE(review): the registration passes ``None`` as the value of "inputs";
    confirm your TensorFlow version accepts ``None`` components in tf.data.
    """
    features = dict(key_map)
    features[target_key] = x
    return features
|
|
|
|
|
|
|
# Location of the pre-built Hugging Face dataset on local disk.
dataset_name = "/researchdisk/lm_training_dataset_full"
dataset_params = {"from_disk_path": dataset_name}

# Reload a dataset saved with `save_to_disk` when a disk path is configured;
# otherwise treat `dataset_params` as keyword arguments for `load_dataset`.
if "from_disk_path" not in dataset_params:
    dataset = load_dataset(**dataset_params)
else:
    dataset = load_from_disk(dataset_params["from_disk_path"])

# Row count of each split, reported to seqio as `num_input_examples`.
dataset_shapes = {
    split_name: dataset[split_name].num_rows
    for split_name in ("train", "validation")
}

# NOTE(review): only "targets" is declared in output_features, so seqio steps
# driven by output_features (e.g. append_eos_after_trim) presumably leave
# "inputs" alone; confirm this is intended for span corruption.
TaskRegistry.add(
    "pretrain_finnish",
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(dataset_fn, dataset=dataset),
        splits=("train", "validation"),
        caching_permitted=False,
        num_input_examples=dataset_shapes,
    ),
    preprocessors=[
        # Place the raw text under "targets"; "inputs" starts as None.
        functools.partial(
            target_to_key,
            key_map=dict(inputs=None, targets=None),
            target_key="targets",
        ),
        seqio.preprocessors.tokenize,
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features=dict(targets=DEFAULT_OUTPUT_FEATURES["targets"]),
    metric_fns=[metrics.accuracy],
)