|
import functools

import seqio
import tensorflow as tf
from datasets import load_dataset
from t5.evaluation import metrics
from seqio import utils

from ul2_objective import ul2_objective
|
# UL2 denoiser mixture settings: R = regular span corruption,
# X = extreme denoising (longer spans and/or higher corruption rates).
R_DENOISER_SPAN_LENGTHS = [3.0, 8.0]
X_DENOISER_SPAN_LENGTHS = [3.0, 8.0, 64.0, 64.0]
R_DENOISER_CORRUPT_RATES = [0.15, 0.15]
X_DENOISER_CORRUPT_RATES = [0.5, 0.5, 0.15, 0.5]
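# Together with the S-denoiser enabled via `use_prefix_lm_task=True` below,
# these lists define seven sub-objectives: two R-denoisers sampled at 0.2
# each, four X-denoisers at 0.1 each, and one S-denoiser (prefix LM) at 0.2,
# summing to 1.0.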
|
|
|
# Mode tokens prepended to each example so the model knows which objective
# produced it (per the UL2 paper: [NLU] for R, [NLG] for X, [S2S] for S).
R_DENOISER_TOKEN_PREFIX = "[NLU]"
X_DENOISER_TOKEN_PREFIX = "[NLG]"
S_DENOISER_TOKEN_PREFIX = "[S2S]"
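# Note: the bracketed mode tokens are assumed to exist as pieces in
# `spiece.model`; if they are ordinary text, SentencePiece will split them
# into several subword pieces rather than a single token.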
|
|
|
TaskRegistry = seqio.TaskRegistry |
|
|
|
# SentencePiece vocabulary; expects `spiece.model` in the working directory.
vocabulary = seqio.SentencePieceVocabulary('spiece.model')
|
|
|
DEFAULT_OUTPUT_FEATURES = {
    # `required=False` lets tasks omit "inputs" (e.g. pure LM objectives).
    "inputs": seqio.Feature(vocabulary=vocabulary, add_eos=True, required=False),
    "targets": seqio.Feature(vocabulary=vocabulary, add_eos=True),
}
|
|
|
def gen_dataset(split, shuffle=False, seed=None, column="text", path=None, name=None):
    """Yield raw text examples from a streaming Hugging Face dataset, looping forever."""
    # Streaming avoids downloading the full dataset up front;
    # `use_auth_token=True` reuses a cached `huggingface-cli login` token.
    dataset = load_dataset(path, name, streaming=True, use_auth_token=True)
    if shuffle:
        if seed:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()
    # Repeat indefinitely so seqio can draw as many examples as training needs.
    while True:
        for item in dataset[str(split)]:
            yield item[column]
|
|
|
|
|
def dataset_fn(split, shuffle_files, seed=None, path=None, name=None):
    """Wrap `gen_dataset` in a tf.data.Dataset of scalar strings for seqio."""
    return tf.data.Dataset.from_generator(
        functools.partial(
            gen_dataset, split, shuffle_files, seed, path=path, name=name
        ),
        output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=path),
    )
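# Quick sanity check (sketch, assuming the dataset registered below is
# accessible with your cached credentials):
#   ds = dataset_fn("train", shuffle_files=False, path="Siddharth63/biological_dataset")
#   print(next(iter(ds)).numpy()[:80])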
|
|
|
|
|
@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Assign the value from the dataset to target_key in key_map."""
    return {**key_map, target_key: x}
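# Example: with key_map={"inputs": "text", "targets": "text"} and
# target_key="targets", a raw string x becomes
# {"inputs": "text", "targets": x}; the placeholder "inputs" value is
# expected to be overwritten when `ul2_objective` builds the corrupted inputs.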
|
|
|
|
|
dataset_name = 'Siddharth63/biological_dataset'
# Non-streaming load, used only to read the split sizes.
dataset = load_dataset(dataset_name)

# Split sizes, e.g. for computing training steps per epoch.
dataset_shapes = {"train": dataset["train"].num_rows,
                  "validation": dataset["validation"].num_rows}
|
|
|
TaskRegistry.add(
    "pretrain_biological_ul2",
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(
            dataset_fn, path="Siddharth63/biological_dataset",
        ),
        splits=("train", "validation"),
        caching_permitted=False,
    ),
    preprocessors=[
        functools.partial(
            target_to_key,
            key_map={
                "inputs": "text",
                "targets": "text",
            },
            target_key="targets",
        ),
        seqio.preprocessors.tokenize,
        functools.partial(
            ul2_objective,
            shard_ds=False,
            use_prefix_lm_task=True,  # adds the S-denoiser (prefix-LM) task
            # Mixture weights: 40% R-denoising, 40% X-denoising, 20% S-denoising.
            rates=[0.4 / len(R_DENOISER_SPAN_LENGTHS)] * len(R_DENOISER_SPAN_LENGTHS)
            + [0.4 / len(X_DENOISER_SPAN_LENGTHS)] * len(X_DENOISER_SPAN_LENGTHS)
            + [0.2],
            mean_noise_span_lengths=R_DENOISER_SPAN_LENGTHS + X_DENOISER_SPAN_LENGTHS,
            noise_densities=R_DENOISER_CORRUPT_RATES + X_DENOISER_CORRUPT_RATES,
            optional_task_prefixes=[R_DENOISER_TOKEN_PREFIX]
            * len(R_DENOISER_SPAN_LENGTHS)
            + [X_DENOISER_TOKEN_PREFIX] * len(X_DENOISER_SPAN_LENGTHS)
            + [S_DENOISER_TOKEN_PREFIX],
            reserved_for_packing=1,  # reserve room for the prepended task token
        ),
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features={
        "targets": DEFAULT_OUTPUT_FEATURES["targets"],
        "inputs": seqio.Feature(vocabulary=vocabulary, add_eos=True),
    },
    metric_fns=[metrics.accuracy],
)
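
# Minimal usage sketch: materialize one preprocessed example from the
# registered task. The sequence lengths below are illustrative assumptions,
# not tuned values.
if __name__ == "__main__":
    task = seqio.TaskRegistry.get("pretrain_biological_ul2")
    ds = task.get_dataset(
        sequence_length={"inputs": 512, "targets": 512},
        split="validation",
        shuffle=False,
    )
    for ex in ds.take(1):
        print({k: v.numpy()[:10] for k, v in ex.items()})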