# pere's picture
# updated dataset
# e957cc1
# /home/perk/mymodel/categorisation-mt5x/tasks.py
import functools
import seqio
import tensorflow_datasets as tfds
from t5.evaluation import metrics
import my_metrics
from t5.data import preprocessors
import t5
import tensorflow.compat.v1 as tf
# Corpus files, one TSV per distinct file.
_TRAIN_TSV = "gs://north-t5x/corpus/deuncaser/norwegian/train.tsv"
_VALIDATION_TSV = "gs://north-t5x/corpus/deuncaser/norwegian/validation.tsv"

# Split name -> file pattern; handed to the task sources as
# `split_to_filepattern`.
tsv_path = {
    "train": _TRAIN_TSV,
    "validation": _VALIDATION_TSV,
    # The "test" split points at the same file as "validation".
    "test": _VALIDATION_TSV,
}
# SentencePiece vocabulary shared by both features below (extra_ids=0).
vocabulary = seqio.SentencePieceVocabulary(
    "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
    extra_ids=0,
)

# "inputs" and "targets" each get their own Feature instance; both use the
# shared vocabulary and add_eos=True.
DEFAULT_OUTPUT_FEATURES = {
    feature_name: seqio.Feature(vocabulary=vocabulary, add_eos=True)
    for feature_name in ("inputs", "targets")
}
def categorise_preprocessor(ds):
    """Turn parsed rows into {"inputs", "targets"} text examples.

    Args:
        ds: dataset whose elements expose string fields "source" and
            "target" (the only keys read here).

    Returns:
        The result of `ds.map(...)`: every element becomes a dict with
        "inputs" derived from "source" and "targets" derived from "target".
    """

    def normalize_text(text):
        """Replace each greedy `'...'` match with the text between the quotes."""
        return tf.strings.regex_replace(text, "'(.*)'", r"\1")

    def to_inputs_and_targets(ex):
        """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
        source_text = normalize_text(ex["source"])
        target_text = normalize_text(ex["target"])
        # Each value is passed through a one-element tf.strings.join.
        return {
            "inputs": tf.strings.join([source_text]),
            "targets": tf.strings.join([target_text]),
        }

    return ds.map(
        to_inputs_and_targets,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
def categorise_fulltext_preprocessor(ds):
    """Build text examples whose targets are full Italian sentences.

    The "target" field is expected to hold a label "0", "1" or "2"; each is
    replaced by a sentence describing the stance towards vaccination. Any
    other value is passed through unchanged.

    Args:
        ds: dataset whose elements expose string fields "source" and
            "target" (the only keys read here).

    Returns:
        The result of `ds.map(...)`: every element becomes a dict with
        "inputs" derived from "source" and "targets" holding the sentence
        for the label found in "target".
    """

    def normalize_text(text):
        """Replace each greedy `'...'` match with the text between the quotes."""
        return tf.strings.regex_replace(text, "'(.*)'", r"\1")

    def fulltext(t):
        """Return the sentence for label `t`; unknown labels are returned as-is."""
        # NOTE(review): inside `ds.map` `t` is presumably a string tensor, so
        # this Python branching likely relies on AutoGraph conversion - confirm.
        if t == "0":
            t = "il testo è favorevole alla vaccinazione"
        elif t == "1":
            t = "il testo è neutro rispetto alla vaccinazione"
        elif t == "2":
            # Typo fix: this sentence used to start with "is testo", unlike
            # the other two labels which start with "il testo".
            t = "il testo è sfavorevole alla vaccinazione"
        return t

    def to_inputs_and_targets(ex):
        """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
        return {
            "inputs": tf.strings.join([normalize_text(ex["source"])]),
            "targets": tf.strings.join([fulltext(normalize_text(ex["target"]))]),
        }

    return ds.map(
        to_inputs_and_targets,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
def categorise_fulltext_word_preprocessor(ds):
    """Build text examples whose targets are single Italian words.

    Labels "0", "1" and "2" found in the "target" field become
    "promozionale", "neutro" and "scoraggiante"; any other value is kept.

    Args:
        ds: dataset whose elements expose string fields "source" and
            "target" (the only keys read here).

    Returns:
        The result of `ds.map(...)`: every element becomes a dict with
        "inputs" derived from "source" and "targets" holding the word for
        the label found in "target".
    """

    def strip_quotes(raw):
        """Replace each greedy `'...'` match with the text between the quotes."""
        return tf.strings.regex_replace(raw, "'(.*)'", r"\1")

    def label_to_word(label):
        """Return the word for `label`; unknown labels are returned as-is."""
        if label == "0":
            label = "promozionale"
        elif label == "1":
            label = "neutro"
        elif label == "2":
            label = "scoraggiante"
        return label

    def to_inputs_and_targets(ex):
        """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
        source_text = strip_quotes(ex["source"])
        target_word = label_to_word(strip_quotes(ex["target"]))
        # Each value is passed through a one-element tf.strings.join.
        return {
            "inputs": tf.strings.join([source_text]),
            "targets": tf.strings.join([target_word]),
        }

    return ds.map(
        to_inputs_and_targets,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
def categorise_binary_preprocessor(ds):
    """Build text examples with the three labels collapsed into two.

    Label "0" in the "target" field is rewritten to "1"; "1" and "2" (and
    any other value) are left as they are.

    Args:
        ds: dataset whose elements expose string fields "source" and
            "target" (the only keys read here).

    Returns:
        The result of `ds.map(...)`: every element becomes a dict with
        "inputs" derived from "source" and "targets" holding the collapsed
        label.
    """

    def strip_quotes(raw):
        """Replace each greedy `'...'` match with the text between the quotes."""
        return tf.strings.regex_replace(raw, "'(.*)'", r"\1")

    def collapse_label(label):
        """Map "0" onto "1"; every other value is returned unchanged."""
        # "1" -> "1" and "2" -> "2" are identity mappings, so only "0"
        # needs an explicit branch.
        if label == "0":
            label = "1"
        return label

    def to_inputs_and_targets(ex):
        """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
        source_text = strip_quotes(ex["source"])
        binary_label = collapse_label(strip_quotes(ex["target"]))
        # Each value is passed through a one-element tf.strings.join.
        return {
            "inputs": tf.strings.join([source_text]),
            "targets": tf.strings.join([binary_label]),
        }

    return ds.map(
        to_inputs_and_targets,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
# Task "deuncaser": read TSV lines, split them into the four named columns,
# build inputs/targets from "source"/"target", then tokenize and append EOS.
_deuncaser_source = seqio.TextLineDataSource(
    split_to_filepattern=tsv_path,
    # num_input_examples=num_nq_examples
)
_deuncaser_parse_tsv = functools.partial(
    t5.data.preprocessors.parse_tsv,
    field_names=["id", "methods", "source", "target"],
)
seqio.TaskRegistry.add(
    "deuncaser",
    source=_deuncaser_source,
    preprocessors=[
        _deuncaser_parse_tsv,
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, metrics.bleu],
    output_features=DEFAULT_OUTPUT_FEATURES,
)
# Task "classify_tweets_fulltext": targets are full sentences produced by
# categorise_fulltext_preprocessor.
# NOTE(review): this task reads the same `tsv_path` files as the "deuncaser"
# task but parses six columns instead of four - confirm the source is intended.
_fulltext_source = seqio.TextLineDataSource(
    split_to_filepattern=tsv_path,
    # num_input_examples=num_nq_examples
)
_fulltext_parse_tsv = functools.partial(
    t5.data.preprocessors.parse_tsv,
    field_names=[
        "annotator1",
        "annotator2",
        "annotator3",
        "target",
        "source",
        "id",
    ],
)
seqio.TaskRegistry.add(
    "classify_tweets_fulltext",
    source=_fulltext_source,
    preprocessors=[
        _fulltext_parse_tsv,
        categorise_fulltext_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro],
    output_features=DEFAULT_OUTPUT_FEATURES,
)
# Task "classify_tweets_binary": targets are the collapsed labels produced by
# categorise_binary_preprocessor.
# NOTE(review): this task reads the same `tsv_path` files as the "deuncaser"
# task but parses six columns instead of four - confirm the source is intended.
_binary_source = seqio.TextLineDataSource(
    split_to_filepattern=tsv_path,
    # num_input_examples=num_nq_examples
)
_binary_parse_tsv = functools.partial(
    t5.data.preprocessors.parse_tsv,
    field_names=[
        "annotator1",
        "annotator2",
        "annotator3",
        "target",
        "source",
        "id",
    ],
)
seqio.TaskRegistry.add(
    "classify_tweets_binary",
    source=_binary_source,
    preprocessors=[
        _binary_parse_tsv,
        categorise_binary_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro],
    output_features=DEFAULT_OUTPUT_FEATURES,
)
# Task "classify_tweets_fulltext_word": targets are single words produced by
# categorise_fulltext_word_preprocessor.
# NOTE(review): this task reads the same `tsv_path` files as the "deuncaser"
# task but parses six columns instead of four - confirm the source is intended.
_fulltext_word_source = seqio.TextLineDataSource(
    split_to_filepattern=tsv_path,
    # num_input_examples=num_nq_examples
)
_fulltext_word_parse_tsv = functools.partial(
    t5.data.preprocessors.parse_tsv,
    field_names=[
        "annotator1",
        "annotator2",
        "annotator3",
        "target",
        "source",
        "id",
    ],
)
seqio.TaskRegistry.add(
    "classify_tweets_fulltext_word",
    source=_fulltext_word_source,
    preprocessors=[
        _fulltext_word_parse_tsv,
        categorise_fulltext_word_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro],
    output_features=DEFAULT_OUTPUT_FEATURES,
)