import functools

import seqio
from t5.data import preprocessors
from t5.evaluation import metrics

# Shared SentencePiece vocabulary used for both inputs and targets.
vocabulary = seqio.SentencePieceVocabulary('spiece.model')

output_features = {
    # 'inputs' is marked required=False: the raw examples carry no 'inputs'
    # field — it is presumably produced during preprocessing (span
    # corruption) rather than read from disk.
    'inputs': seqio.Feature(vocabulary=vocabulary, add_eos=True, required=False),
    'targets': seqio.Feature(vocabulary=vocabulary, add_eos=True)
}
# Register a T5-style span-corruption pretraining task over plain-text
# Finnish data. Each split is a text file with one example per line.
seqio.TaskRegistry.add(
    'pretrain_finnish',
    source=seqio.TextLineDataSource({
        "train": "/researchdisk/lm_training_dataset_full_sentences/train.txt",
        "validation": "/researchdisk/lm_training_dataset_full_sentences/validation.txt"
    }),
    preprocessors=[
        # Single field per line: with "\n" as the delimiter the whole line
        # becomes the "text" field.
        functools.partial(
            preprocessors.parse_tsv,
            field_names=["text"],
            field_delim="\n"),
        # Expose the raw line as 'targets'; 'inputs' maps to None because
        # span corruption derives the inputs from 'targets' below.
        functools.partial(
            preprocessors.rekey, key_map={
                "inputs": None,
                "targets": "text"
            }),
        seqio.preprocessors.tokenize,
        # Marks where an offline-cached dataset may be substituted.
        seqio.CacheDatasetPlaceholder(),
        # Masks random token spans to build the denoising inputs/targets.
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,
    ],
    metric_fns=[metrics.accuracy],
    output_features=output_features)
# Example: materialize the registered task and inspect a few examples.
# dataset = seqio.get_mixture_or_task("pretrain_finnish").get_dataset(
#     sequence_length={"inputs": 512, "targets": 114},
#     split="train",
#     shuffle=True,
#     num_epochs=1,
#     #shard_info=seqio.ShardInfo(index=0, num_shards=10),
#     use_cached=False,
#     seed=42
# )
# # Print the first 5 examples.
# for _, ex in zip(range(5), dataset.as_numpy_iterator()):
#     print(ex)