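"""SeqIO task registry for Norwegian T5/T5X fine-tuning.

Defines TSV-backed tasks for parliament-speech classification, NoReC
sentiment classification, and Bokmål/Nynorsk translation, in several
vocabulary variants (English T5, Scandinavian, and mT5).
"""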
import functools

import seqio
import t5
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
from t5.data import preprocessors
from t5.evaluation import metrics

import my_metrics

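# Split-to-filepath mappings for the fine-tuning datasets (GCS buckets).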
tsv_parliament_path = {
    "train": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv/train.tsv",
    "validation": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv/dev.tsv",
    "test": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv/test.tsv"
}

tsv_parliament_max300_path = {
    "train": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv_max300/train.tsv",
    "validation": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv_max300/dev.tsv",
    "test": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv_max300/test.tsv"
}

tsv_translate_path = {
    "train": "gs://nb-t5x-us-central2/corpus_bokmal_nynorsk/train.tsv",
    "validation": "gs://nb-t5x-us-central2/corpus_bokmal_nynorsk/dev.tsv",
    "test": "gs://nb-t5x-us-central2/corpus_bokmal_nynorsk/test.tsv"
}

tsv_translate_long_path = {
    "train": "gs://nb-t5x-us-central2/corpus_bokmal_nynorsk/train_long.tsv",
    "validation": "gs://nb-t5x-us-central2/corpus_bokmal_nynorsk/dev.tsv",
    "test": "gs://nb-t5x-us-central2/corpus_bokmal_nynorsk/test.tsv"
}

tsv_sentiment_path = {
    "train": "gs://notram-public/finetune_datasets/norec_sentiment/train.tsv",
    "validation": "gs://notram-public/finetune_datasets/norec_sentiment/dev.tsv",
    "test": "gs://notram-public/finetune_datasets/norec_sentiment/test.tsv"
}

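# Angry Tweets and DaNE ship no separate dev file, so "validation"
# points at the test split for those datasets.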
json_angry_tweets_path = {
    "train": "gs://notram-public/finetune_datasets/angry_tweets/train.jsonl",
    "validation": "gs://notram-public/finetune_datasets/angry_tweets/test.jsonl",
    "test": "gs://notram-public/finetune_datasets/angry_tweets/test.jsonl"
}

tsv_angry_tweets_path = {
    "train": "gs://notram-public/finetune_datasets/angry_tweets/train.tsv",
    "validation": "gs://notram-public/finetune_datasets/angry_tweets/test.tsv",
    "test": "gs://notram-public/finetune_datasets/angry_tweets/test.tsv"
}

tsv_dane_path = {
    "train": "gs://notram-public/finetune_datasets/dane/train.tsv",
    "validation": "gs://notram-public/finetune_datasets/dane/test.tsv",
    "test": "gs://notram-public/finetune_datasets/dane/test.tsv"
}

tsv_dane_tokens_path = {
    "train": "gs://notram-public/finetune_datasets/dane/train_tokens.tsv",
    "validation": "gs://notram-public/finetune_datasets/dane/test_tokens.tsv",
    "test": "gs://notram-public/finetune_datasets/dane/test_tokens.tsv"
}

tsv_dane_long_tokens_path = {
    "train": "gs://notram-public/finetune_datasets/dane/train_long_tokens.tsv",
    "validation": "gs://notram-public/finetune_datasets/dane/test_long_tokens.tsv",
    "test": "gs://notram-public/finetune_datasets/dane/test_long_tokens.tsv"
}

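# SentencePiece vocabularies: a Scandinavian model built from
# no/da/en/sv/nn/is Wikipedia, the original English T5 vocabulary,
# and the mT5 (mC4) vocabulary.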
scand_vocabulary = seqio.SentencePieceVocabulary(
    "gs://nb-t5/t5/vocabs/wikipedia/no-da-en-sv-nn-is_32000_unigram.sp.model", extra_ids=100)
eng_vocabulary = seqio.SentencePieceVocabulary(
    "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model", extra_ids=0)
mt5_vocabulary = seqio.SentencePieceVocabulary(
    "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model", extra_ids=0)

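# Output features pairing each vocabulary with the "inputs"/"targets"
# fields; "inputs" is optional (required=False).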
DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=eng_vocabulary, add_eos=True, required=False),
    "targets": seqio.Feature(
        vocabulary=eng_vocabulary, add_eos=True)
}

SCAND_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=scand_vocabulary, add_eos=True, required=False),
    "targets": seqio.Feature(
        vocabulary=scand_vocabulary, add_eos=True)
}

MT5_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=mt5_vocabulary, add_eos=True, required=False),
    "targets": seqio.Feature(
        vocabulary=mt5_vocabulary, add_eos=True)
}

def categorise_preprocessor(ds):
    """Converts {"source", "target"} examples into {"inputs", "targets"}."""

    def normalize_text(text):
        """Lowercase and remove quotes from a TensorFlow string."""
        # Minimal implementation matching the docstring: lowercase the
        # string and strip a surrounding pair of single quotes.
        text = tf.strings.lower(text)
        text = tf.strings.regex_replace(text, "'(.*)'", r"\1")
        return text

    def to_inputs_and_targets(ex):
        """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
        return {
            "inputs": normalize_text(ex["source"]),
            "targets": normalize_text(ex["target"]),
        }

    return ds.map(to_inputs_and_targets,
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)

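# Task registrations. The *_scand and *_mt5 variants differ only in their
# output features (vocabulary). Note the field order passed to parse_tsv:
# the classification TSVs store the label first ("target", "source"),
# while the translation TSVs store the source sentence first.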
seqio.TaskRegistry.add(
    "parliament_max300",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_parliament_max300_path,
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["target", "source"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

seqio.TaskRegistry.add(
    "parliament_max300_scand",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_parliament_max300_path,
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["target", "source"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro],
    output_features=SCAND_OUTPUT_FEATURES,
)

seqio.TaskRegistry.add(
    "parliament_max300_mt5",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_parliament_max300_path,
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["target", "source"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro],
    output_features=MT5_OUTPUT_FEATURES,
)

seqio.TaskRegistry.add(
    "sentiment",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_sentiment_path,
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["target", "source"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

seqio.TaskRegistry.add(
    "translate",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_translate_path,
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["source", "target"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro, metrics.bleu],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

seqio.TaskRegistry.add(
    "translate_long_scand",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_translate_long_path,
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["source", "target"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro, metrics.bleu],
    output_features=SCAND_OUTPUT_FEATURES,
)

seqio.TaskRegistry.add(
    "translate_long",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_translate_long_path,
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["source", "target"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro, metrics.bleu],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

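# A minimal usage sketch (an illustration, not part of the registry itself):
# pulls a few tokenized validation examples from the "sentiment" task.
# Assumes the gs:// buckets referenced above are readable from this
# environment.
if __name__ == "__main__":
    sentiment_task = seqio.get_mixture_or_task("sentiment")
    dataset = sentiment_task.get_dataset(
        sequence_length={"inputs": 512, "targets": 8},
        split="validation",
        shuffle=False,
    )
    for example in tfds.as_numpy(dataset.take(2)):
        print(example)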