# Commit 213f85b by pere: "mt5 vocab"
# /home/perk/mymodel/categorisation-mt5x/tasks.py
import functools
import seqio
import my_metrics
import tensorflow_datasets as tfds
from t5.evaluation import metrics
from t5.data import preprocessors
#import my_preprocessors
import t5
import tensorflow.compat.v1 as tf
# --- Norwegian fine-tuning datasets: split name -> TSV file pattern. ---

# Parliament speeches 1998-2016 (FrP vs. SV), full-length speeches.
tsv_parliament_path = {
    "train": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv/train.tsv",
    "validation": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv/dev.tsv",
    "test": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv/test.tsv",
}

# Same corpus, "max300" variant.
tsv_parliament_max300_path = {
    "train": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv_max300/train.tsv",
    "validation": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv_max300/dev.tsv",
    "test": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv_max300/test.tsv",
}

# Bokmål <-> Nynorsk translation corpus.
tsv_translate_path = {
    "train": "gs://nb-t5x-us-central2/corpus_bokmal_nynorsk/train.tsv",
    "validation": "gs://nb-t5x-us-central2/corpus_bokmal_nynorsk/dev.tsv",
    "test": "gs://nb-t5x-us-central2/corpus_bokmal_nynorsk/test.tsv",
}

# Same translation corpus with a long-example training file; dev/test are
# shared with tsv_translate_path.
tsv_translate_long_path = {
    "train": "gs://nb-t5x-us-central2/corpus_bokmal_nynorsk/train_long.tsv",
    "validation": "gs://nb-t5x-us-central2/corpus_bokmal_nynorsk/dev.tsv",
    "test": "gs://nb-t5x-us-central2/corpus_bokmal_nynorsk/test.tsv",
}

# NoReC sentiment classification.
tsv_sentiment_path = {
    "train": "gs://notram-public/finetune_datasets/norec_sentiment/train.tsv",
    "validation": "gs://notram-public/finetune_datasets/norec_sentiment/dev.tsv",
    "test": "gs://notram-public/finetune_datasets/norec_sentiment/test.tsv",
}
# --- Danish fine-tuning datasets: split name -> file pattern. ---
# NOTE(review): in every Danish dataset below, "validation" points at the
# test file — presumably because no separate dev split exists; confirm.

# AngryTweets, JSON-lines version.
json_angry_tweets_path = {
    "train": "gs://notram-public/finetune_datasets/angry_tweets/train.jsonl",
    "validation": "gs://notram-public/finetune_datasets/angry_tweets/test.jsonl",
    "test": "gs://notram-public/finetune_datasets/angry_tweets/test.jsonl",
}

# AngryTweets, TSV version.
tsv_angry_tweets_path = {
    "train": "gs://notram-public/finetune_datasets/angry_tweets/train.tsv",
    "validation": "gs://notram-public/finetune_datasets/angry_tweets/test.tsv",
    "test": "gs://notram-public/finetune_datasets/angry_tweets/test.tsv",
}

# DaNE, three file variants: plain, "_tokens" and "_long_tokens".
tsv_dane_path = {
    "train": "gs://notram-public/finetune_datasets/dane/train.tsv",
    "validation": "gs://notram-public/finetune_datasets/dane/test.tsv",
    "test": "gs://notram-public/finetune_datasets/dane/test.tsv",
}
tsv_dane_tokens_path = {
    "train": "gs://notram-public/finetune_datasets/dane/train_tokens.tsv",
    "validation": "gs://notram-public/finetune_datasets/dane/test_tokens.tsv",
    "test": "gs://notram-public/finetune_datasets/dane/test_tokens.tsv",
}
tsv_dane_long_tokens_path = {
    "train": "gs://notram-public/finetune_datasets/dane/train_long_tokens.tsv",
    "validation": "gs://notram-public/finetune_datasets/dane/test_long_tokens.tsv",
    "test": "gs://notram-public/finetune_datasets/dane/test_long_tokens.tsv",
}
# --- SentencePiece vocabularies used by the output-feature specs. ---

# Scandinavian (no/da/en/sv/nn/is) 32k unigram model. extra_ids=100 asks
# seqio to add 100 extra (sentinel) ids on top of the model's pieces.
scand_vocabulary = seqio.SentencePieceVocabulary(
    'gs://nb-t5/t5/vocabs/wikipedia/no-da-en-sv-nn-is_32000_unigram.sp.model',
    extra_ids=100)

# English T5 vocabulary. The "100extra" in the path suggests the sentinel
# ids are already baked into the model, hence extra_ids=0 — confirm.
eng_vocabulary = seqio.SentencePieceVocabulary(
    'gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model',
    extra_ids=0)

# mT5 multilingual 250k vocabulary (same "100extra" convention as above).
mt5_vocabulary = seqio.SentencePieceVocabulary(
    'gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model',
    extra_ids=0)
def _make_output_features(vocabulary):
    """Build the ``inputs``/``targets`` feature spec for one vocabulary.

    Both features use ``vocabulary`` and append EOS. ``inputs`` is marked
    ``required=False``; ``targets`` keeps seqio's default. Each call returns
    new ``seqio.Feature`` objects, so the specs below share nothing.
    """
    inputs_feature = seqio.Feature(
        vocabulary=vocabulary, add_eos=True, required=False)
    targets_feature = seqio.Feature(vocabulary=vocabulary, add_eos=True)
    return {"inputs": inputs_feature, "targets": targets_feature}


# English T5 vocabulary.
DEFAULT_OUTPUT_FEATURES = _make_output_features(eng_vocabulary)
# Scandinavian vocabulary.
SCAND_OUTPUT_FEATURES = _make_output_features(scand_vocabulary)
# mT5 multilingual vocabulary.
MT5_OUTPUT_FEATURES = _make_output_features(mt5_vocabulary)
def categorise_preprocessor(ds):
    """Rename parsed TSV fields to the seqio ``inputs``/``targets`` schema.

    Every element of ``ds`` is expected to carry ``"source"`` and
    ``"target"`` features (the tasks in this file produce them with
    ``parse_tsv``). They are emitted as ``"inputs"`` and ``"targets"``,
    after passing each through ``normalize_text``.

    Args:
      ds: a ``tf.data.Dataset`` of dicts with ``"source"``/``"target"`` keys.

    Returns:
      The mapped ``tf.data.Dataset`` of dicts with ``"inputs"``/``"targets"``.
    """

    def normalize_text(text):
        """Normalisation hook; currently returns ``text`` unchanged.

        No lowercasing or quote removal is applied. Adding a transformation
        here (e.g. ``tf.strings.lower``) affects every task that uses this
        preprocessor.
        """
        return text

    def to_inputs_and_targets(ex):
        """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
        return {
            "inputs": normalize_text(ex["source"]),
            "targets": normalize_text(ex["target"]),
        }

    return ds.map(to_inputs_and_targets,
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
def _register_tsv_task(name, split_paths, field_names, metric_fns,
                       output_features):
    """Register a seqio task that reads TSV lines and tokenizes them.

    Preprocessing pipeline:
      1. ``t5.data.preprocessors.parse_tsv`` with ``field_names``;
      2. ``categorise_preprocessor`` (source/target -> inputs/targets);
      3. ``seqio.preprocessors.tokenize_and_append_eos``.

    Args:
      name: task name under which it is added to ``seqio.TaskRegistry``.
      split_paths: mapping from split name to file pattern.
      field_names: TSV column names, in file order.
      metric_fns: metric functions evaluated for this task.
      output_features: seqio feature spec (vocabulary per feature).
    """
    seqio.TaskRegistry.add(
        name,
        source=seqio.TextLineDataSource(split_to_filepattern=split_paths),
        preprocessors=[
            functools.partial(
                t5.data.preprocessors.parse_tsv,
                field_names=list(field_names)),
            categorise_preprocessor,
            seqio.preprocessors.tokenize_and_append_eos,
        ],
        # Fresh list per task, matching the original per-call literals.
        metric_fns=list(metric_fns),
        output_features=output_features,
    )


# Column orders passed to parse_tsv.
# Classification files: label column first, then the text.
_LABEL_FIRST = ("target", "source")
# Translation files: source sentence first, then the target sentence.
_SOURCE_FIRST = ("source", "target")

_CLASSIFICATION_METRICS = (metrics.accuracy, my_metrics.f1_macro)
# NOTE(review): accuracy/f1_macro compare whole output strings; for free-form
# translation BLEU is likely the informative metric — confirm the others
# are wanted.
_TRANSLATION_METRICS = (metrics.accuracy, my_metrics.f1_macro, metrics.bleu)

# Registration order is the same as the original explicit calls.
_register_tsv_task("parliament_max300", tsv_parliament_max300_path,
                   _LABEL_FIRST, _CLASSIFICATION_METRICS,
                   DEFAULT_OUTPUT_FEATURES)
_register_tsv_task("parliament_max300_scand", tsv_parliament_max300_path,
                   _LABEL_FIRST, _CLASSIFICATION_METRICS,
                   SCAND_OUTPUT_FEATURES)
_register_tsv_task("parliament_max300_mt5", tsv_parliament_max300_path,
                   _LABEL_FIRST, _CLASSIFICATION_METRICS,
                   MT5_OUTPUT_FEATURES)
_register_tsv_task("sentiment", tsv_sentiment_path,
                   _LABEL_FIRST, _CLASSIFICATION_METRICS,
                   DEFAULT_OUTPUT_FEATURES)
_register_tsv_task("translate", tsv_translate_path,
                   _SOURCE_FIRST, _TRANSLATION_METRICS,
                   DEFAULT_OUTPUT_FEATURES)
_register_tsv_task("translate_long_scand", tsv_translate_long_path,
                   _SOURCE_FIRST, _TRANSLATION_METRICS,
                   SCAND_OUTPUT_FEATURES)
_register_tsv_task("translate_long", tsv_translate_long_path,
                   _SOURCE_FIRST, _TRANSLATION_METRICS,
                   DEFAULT_OUTPUT_FEATURES)