# File-view metadata captured along with this file:
#   last commit: fdc60c1 by aapot -- "Saving weights and logs of step 20000"
#   size: 1.54 kB; malware scan: no virus
import functools
import seqio
from t5.evaluation import metrics
from t5.data import preprocessors
# SentencePiece vocabulary used for both the input and target features.
# 'spiece.model' is a relative path, so presumably it is resolved against the
# working directory of the process that loads it.
# NOTE(review): no `extra_ids` argument is passed here. T5 span corruption
# takes its sentinel ids from the top of the vocabulary (t5's `sentinel_id`
# uses vocab_size - 1 and counts down -- confirm for the installed version).
# Unless spiece.model already reserves sentinel pieces at the end of its
# vocab, sentinels will collide with real subwords. The T5 default is
# extra_ids=100 -- confirm before changing, since it alters the data
# fed to an in-progress run.
vocabulary = seqio.SentencePieceVocabulary('spiece.model')

# Output feature spec for the task. Both features append EOS. 'inputs' is
# marked required=False; 'targets' keeps seqio's default for `required`.
output_features = dict(
    inputs=seqio.Feature(vocabulary=vocabulary, add_eos=True, required=False),
    targets=seqio.Feature(vocabulary=vocabulary, add_eos=True),
)
# Register the 'pretrain_finnish' span-corruption pretraining task.
#
# Data: one example per line of the train/validation text files.
# Preprocessing pipeline, applied in list order:
#   1. parse_tsv    -- turn each raw line into a {"text": line} dict. The
#                      delimiter is "\n", which cannot occur inside a single
#                      line, so the whole line (tabs included) stays one field.
#   2. rekey        -- copy "text" into "targets". "inputs" maps to None,
#                      which seqio's rekey turns into an empty value.
#   3. tokenize     -- encode the string features with `output_features`.
#   4. cache marker -- the steps above are the cacheable prefix of the pipeline.
#   5. span_corruption       -- build the T5 denoising (inputs, targets) pairs.
#   6. append_eos_after_trim -- add EOS to features whose add_eos is True.
seqio.TaskRegistry.add(
    'pretrain_finnish',
    source=seqio.TextLineDataSource(
        dict(
            train="/researchdisk/lm_training_dataset_full_sentences/train.txt",
            validation="/researchdisk/lm_training_dataset_full_sentences/validation.txt",
        )
    ),
    preprocessors=[
        functools.partial(
            preprocessors.parse_tsv, field_names=["text"], field_delim="\n"
        ),
        functools.partial(
            preprocessors.rekey, key_map=dict(inputs=None, targets="text")
        ),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,
    ],
    metric_fns=[metrics.accuracy],
    output_features=output_features,
)
# dataset = seqio.get_mixture_or_task("pretrain_finnish").get_dataset(
# sequence_length={"inputs": 512, "targets": 114},
# split="train",
# shuffle=True,
# num_epochs=1,
# #shard_info=seqio.ShardInfo(index=0, num_shards=10),
# use_cached=False,
# seed=42
# )
# # Print the first 5 examples.
# for _, ex in zip(range(5), dataset.as_numpy_iterator()):
# print(ex)