# /home/perk/mymodel/categorisation-mt5x/tasks.py
"""seqio task definitions.

Registers four tasks on top of an mT5 SentencePiece vocabulary:
  * "deuncaser"                     - Norwegian de-uncasing (seq2seq restore).
  * "classify_tweets_fulltext"      - Italian vaccine-stance label as a sentence.
  * "classify_tweets_fulltext_word" - stance label as a single Italian word.
  * "classify_tweets_binary"        - stance label collapsed to two classes.
"""

import functools

import seqio
import tensorflow_datasets as tfds
from t5.evaluation import metrics
import my_metrics
from t5.data import preprocessors
import t5
import tensorflow.compat.v1 as tf

# NOTE(review): "test" deliberately(?) points at the validation file — confirm
# there is no held-out test split rather than a copy-paste slip.
tsv_path = {
    "train": "gs://north-t5x/corpus/deuncaser/norwegian/train.tsv",
    "validation": "gs://north-t5x/corpus/deuncaser/norwegian/validation.tsv",
    "test": "gs://north-t5x/corpus/deuncaser/norwegian/validation.tsv",
}

vocabulary = seqio.SentencePieceVocabulary(
    'gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model',
    extra_ids=0)

DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(vocabulary=vocabulary, add_eos=True),
    "targets": seqio.Feature(vocabulary=vocabulary, add_eos=True),
}


def _normalize_text(text):
    """Strip one pair of surrounding single quotes from a string tensor.

    (Shared helper; the four preprocessors previously each carried an
    identical copy whose docstring wrongly claimed lowercasing — only the
    quote removal below ever happens.)
    """
    return tf.strings.regex_replace(text, "'(.*)'", r"\1")


def categorise_preprocessor(ds):
    """Map parsed TSV rows {"source", "target"} to {"inputs", "targets"}."""

    def to_inputs_and_targets(ex):
        return {
            "inputs": _normalize_text(ex["source"]),
            "targets": _normalize_text(ex["target"]),
        }

    return ds.map(to_inputs_and_targets,
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)


def categorise_fulltext_preprocessor(ds):
    """Like categorise_preprocessor, but rewrites the numeric stance label
    ("0"/"1"/"2") as a full Italian sentence target."""

    def fulltext(t):
        # AutoGraph converts these tensor comparisons to tf.cond inside
        # ds.map; unmatched labels pass through unchanged.
        if t == "0":
            t = "il testo è favorevole alla vaccinazione"
        elif t == "1":
            t = "il testo è neutro rispetto alla vaccinazione"
        elif t == "2":
            # Fixed typo: was "is testo", inconsistent with the other labels.
            t = "il testo è sfavorevole alla vaccinazione"
        return t

    def to_inputs_and_targets(ex):
        return {
            "inputs": _normalize_text(ex["source"]),
            "targets": fulltext(_normalize_text(ex["target"])),
        }

    return ds.map(to_inputs_and_targets,
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)


def categorise_fulltext_word_preprocessor(ds):
    """Like categorise_fulltext_preprocessor, but the target is a single
    Italian word per stance class."""

    def fulltext(t):
        if t == "0":
            t = "promozionale"
        elif t == "1":
            t = "neutro"
        elif t == "2":
            t = "scoraggiante"
        return t

    def to_inputs_and_targets(ex):
        return {
            "inputs": _normalize_text(ex["source"]),
            "targets": fulltext(_normalize_text(ex["target"])),
        }

    return ds.map(to_inputs_and_targets,
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)


def categorise_binary_preprocessor(ds):
    """Collapse the three stance labels to two: "0" and "1" both map to "1"
    (favorevole/neutro), "2" stays "2" (sfavorevole)."""

    def fulltext(t):
        if t == "0":
            t = "1"
        elif t == "1":
            t = "1"
        elif t == "2":
            t = "2"
        return t

    def to_inputs_and_targets(ex):
        return {
            "inputs": _normalize_text(ex["source"]),
            "targets": fulltext(_normalize_text(ex["target"])),
        }

    return ds.map(to_inputs_and_targets,
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)


seqio.TaskRegistry.add(
    "deuncaser",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_path,
        # num_input_examples=num_nq_examples
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["id", "methods", "source", "target"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, metrics.bleu],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

# NOTE(review): the three tweet tasks below reuse the deuncaser tsv_path but
# parse six columns — verify the path shouldn't point at a tweet corpus.
seqio.TaskRegistry.add(
    "classify_tweets_fulltext",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_path,
        # num_input_examples=num_nq_examples
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["annotator1", "annotator2", "annotator3",
                         "target", "source", "id"]),
        categorise_fulltext_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

seqio.TaskRegistry.add(
    "classify_tweets_binary",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_path,
        # num_input_examples=num_nq_examples
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["annotator1", "annotator2", "annotator3",
                         "target", "source", "id"]),
        categorise_binary_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

seqio.TaskRegistry.add(
    "classify_tweets_fulltext_word",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_path,
        # num_input_examples=num_nq_examples
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["annotator1", "annotator2", "annotator3",
                         "target", "source", "id"]),
        categorise_fulltext_word_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro],
    output_features=DEFAULT_OUTPUT_FEATURES,
)