"""Register the ``pretrain_finnish`` seqio task.

Importing this module has a side effect: it adds a task named
``pretrain_finnish`` to ``seqio.TaskRegistry``. The task reads plain-text
line files, tokenizes them with a SentencePiece vocabulary and applies the
T5 ``span_corruption`` preprocessor.
"""

import functools

import seqio
from t5.data import preprocessors
from t5.evaluation import metrics

# SentencePiece model shared by both features. The path is relative, so it is
# presumably resolved against the process working directory -- TODO confirm
# against how training is launched.
vocabulary = seqio.SentencePieceVocabulary('spiece.model')

# Both features use the same vocabulary and append EOS; "inputs" is declared
# with required=False.
output_features = {
    'inputs': seqio.Feature(
        vocabulary=vocabulary, add_eos=True, required=False),
    'targets': seqio.Feature(vocabulary=vocabulary, add_eos=True),
}

seqio.TaskRegistry.add(
    'pretrain_finnish',
    source=seqio.TextLineDataSource({
        "train": "/researchdisk/lm_training_dataset_full_sentences/train.txt",
        "validation": "/researchdisk/lm_training_dataset_full_sentences/validation.txt",
    }),
    preprocessors=[
        # Parse each line into a single "text" field. "\n" is used as the
        # field delimiter, presumably so that a line is never split into
        # several columns -- TODO confirm against parse_tsv's behaviour.
        functools.partial(
            preprocessors.parse_tsv,
            field_names=["text"],
            field_delim="\n"),
        # Expose the raw text under "targets"; "inputs" is mapped from None
        # (NOTE(review): rekey is expected to fill such keys with an empty
        # value -- verify in the t5 preprocessors documentation).
        functools.partial(
            preprocessors.rekey,
            key_map={
                "inputs": None,
                "targets": "text",
            }),
        seqio.preprocessors.tokenize,
        # Marks the point up to which the pipeline may be cached offline;
        # the steps listed after it come later in the pipeline.
        seqio.CacheDatasetPlaceholder(),
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,
    ],
    metric_fns=[metrics.accuracy],
    output_features=output_features)