# /home/perk/mymodel/categorisation-mt5x/tasks.py
import functools

import seqio
import my_metrics
import tensorflow_datasets as tfds
from t5.evaluation import metrics
from t5.data import preprocessors
#import my_preprocessors
import t5
import tensorflow.compat.v1 as tf

tsv_parliament_path = {
    "train": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv/train.tsv",
    "validation": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv/dev.tsv",
    "test": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv/test.tsv"
}

tsv_summary_path = {
    "train": "gs://north-t5x/corpus/summary_test/norwegian_train.tsv",
    "validation": "gs://north-t5x/corpus/summary_test/test.tsv",
    "test": "gs://north-t5x/corpus/summary_test/test.tsv"
}

tsv_summary_all_path = {
    "train": "gs://north-t5x/corpus/summary_test/cnn_and_norwegian_train.tsv",
    "validation": "gs://north-t5x/corpus/summary_test/test.tsv",
    "test": "gs://north-t5x/corpus/summary_test/test.tsv"
}

tsv_translate_path = {
    "train": "gs://nb-t5x-us-central2/corpus_en_no/train.tsv",
    "validation": "gs://nb-t5x-us-central2/corpus_en_no/dev.tsv",
    "test": "gs://nb-t5x-us-central2/corpus_en_no/test.tsv"
}

tsv_sentiment_path = {
    "train": "gs://notram-public/finetune_datasets/norec_sentiment/train.tsv",
    "validation": "gs://notram-public/finetune_datasets/norec_sentiment/dev.tsv",
    "test": "gs://notram-public/finetune_datasets/norec_sentiment/test.tsv"
}

json_angry_tweets_path = {
    "train": "gs://notram-public/finetune_datasets/angry_tweets/train.jsonl",
    "validation": "gs://notram-public/finetune_datasets/angry_tweets/test.jsonl",
    "test": "gs://notram-public/finetune_datasets/angry_tweets/test.jsonl"
}

tsv_angry_tweets_path = {
    "train": "gs://notram-public/finetune_datasets/angry_tweets/train.tsv",
    "validation": "gs://notram-public/finetune_datasets/angry_tweets/test.tsv",
    "test": "gs://notram-public/finetune_datasets/angry_tweets/test.tsv"
}

tsv_dane_path = {
    "train": "gs://notram-public/finetune_datasets/dane/train.tsv",
    "validation": "gs://notram-public/finetune_datasets/dane/test.tsv",
    "test": "gs://notram-public/finetune_datasets/dane/test.tsv"
}

tsv_dane_tokens_path = {
    "train": "gs://notram-public/finetune_datasets/dane/train_tokens.tsv",
    "validation": "gs://notram-public/finetune_datasets/dane/test_tokens.tsv",
    "test": "gs://notram-public/finetune_datasets/dane/test_tokens.tsv"
}

tsv_dane_long_tokens_path = {
    "train": "gs://notram-public/finetune_datasets/dane/train_long_tokens.tsv",
    "validation": "gs://notram-public/finetune_datasets/dane/test_long_tokens.tsv",
    "test": "gs://notram-public/finetune_datasets/dane/test_long_tokens.tsv"
}

scand_vocabulary = seqio.SentencePieceVocabulary(
    'gs://nb-t5/t5/vocabs/wikipedia/no-da-en-sv-nn-is_32000_unigram.sp.model',
    extra_ids=100)
eng_vocabulary = seqio.SentencePieceVocabulary(
    'gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model',
    extra_ids=0)
mt5_vocabulary = seqio.SentencePieceVocabulary(
    'gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model',
    extra_ids=0)

DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=eng_vocabulary, add_eos=True, required=False),
    "targets": seqio.Feature(
        vocabulary=eng_vocabulary, add_eos=True)
}

SCAND_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=scand_vocabulary, add_eos=True, required=False),
    "targets": seqio.Feature(
        vocabulary=scand_vocabulary, add_eos=True)
}

MT5_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=mt5_vocabulary, add_eos=True, required=False),
    "targets": seqio.Feature(
        vocabulary=mt5_vocabulary, add_eos=True)
}
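
# Optional sanity-check sketch (not part of the original pipeline): confirm
# each SentencePiece model loads and report its size before any task uses it.
# Assumes the GCS paths above are readable from this environment.
def _check_vocabularies():
    for name, vocab in [("scand", scand_vocabulary),
                        ("eng", eng_vocabulary),
                        ("mt5", mt5_vocabulary)]:
        # For seqio.SentencePieceVocabulary, vocab_size includes extra_ids.
        print(f"{name}: vocab_size={vocab.vocab_size}")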

def categorise_preprocessor(ds):
  def normalize_text(text):
    """Normalization hook; currently a pass-through.

    The quote-stripping step below is disabled.
    """
    #text = tf.strings.regex_replace(text,"'(.*)'", r"\1")
    return text

  def to_inputs_and_targets(ex):
    """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
    return {
        "inputs": tf.strings.join([normalize_text(ex["source"])]),
        "targets": tf.strings.join([normalize_text(ex["target"])]),
    }

  return ds.map(to_inputs_and_targets,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)


# Note: the classification TSVs below store the label in the first column,
# hence field_names=["target", "source"]; the summarization TSVs use the
# opposite order.
seqio.TaskRegistry.add(
    "parliament",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_parliament_path,
        #num_input_examples=num_nq_examples
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["target", "source"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

seqio.TaskRegistry.add(
    "sentiment",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_sentiment_path,
        #num_input_examples=num_nq_examples
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["target", "source"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

seqio.TaskRegistry.add(
    "angry_tweets",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_angry_tweets_path,
        #num_input_examples=num_nq_examples
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["target", "source"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

seqio.TaskRegistry.add(
    "dane",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_dane_long_tokens_path,
        #num_input_examples=num_nq_examples
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["placeholder1", "placeholder2", "placeholder3",
                         "target", "source"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

seqio.TaskRegistry.add(
    "summary_scand",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_summary_path,
        #num_input_examples=num_nq_examples
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["source", "target"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro,
                metrics.bleu, metrics.rouge],
    output_features=SCAND_OUTPUT_FEATURES,
)

seqio.TaskRegistry.add(
    "summary",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_summary_path,
        #num_input_examples=num_nq_examples
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["source", "target"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro,
                metrics.bleu, metrics.rouge],
    output_features=MT5_OUTPUT_FEATURES,
)

seqio.TaskRegistry.add(
    "summary_all",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_summary_all_path,
        #num_input_examples=num_nq_examples
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["source", "target"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro,
                metrics.bleu, metrics.rouge],
    output_features=MT5_OUTPUT_FEATURES,
)
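
# The registrations above and below all share one shape: a TSV source, the
# parse_tsv -> categorise -> tokenize pipeline, metrics, and output features.
# An illustrative helper like the one below could factor out the repetition;
# it is a sketch only and is not used by the registrations in this file.
def _register_tsv_task(name, split_to_filepattern, field_names,
                       metric_fns, output_features):
    seqio.TaskRegistry.add(
        name,
        source=seqio.TextLineDataSource(
            split_to_filepattern=split_to_filepattern),
        preprocessors=[
            functools.partial(
                t5.data.preprocessors.parse_tsv, field_names=field_names),
            categorise_preprocessor,
            seqio.preprocessors.tokenize_and_append_eos,
        ],
        metric_fns=metric_fns,
        output_features=output_features,
    )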

seqio.TaskRegistry.add(
    "summary_all_scand",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_summary_all_path,
        #num_input_examples=num_nq_examples
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["source", "target"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro,
                metrics.bleu, metrics.rouge],
    output_features=SCAND_OUTPUT_FEATURES,
)
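
# Minimal usage sketch (assumes GCS access; the task name and sequence
# lengths here are examples, not values taken from the original file):
# materialize one validation example from a registered task.
if __name__ == "__main__":
    task = seqio.TaskRegistry.get("parliament")
    ds = task.get_dataset(
        sequence_length={"inputs": 512, "targets": 512},
        split="validation",
        shuffle=False,
    )
    for ex in ds.take(1):
        print({k: v.numpy() for k, v in ex.items()})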