File size: 2,860 Bytes
4557487
 
 
 
81de315
4557487
 
 
 
 
 
2b404b0
4f33d95
 
 
4557487
 
2b404b0
 
 
 
 
 
4557487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b404b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4557487
2b404b0
4557487
 
 
 
 
 
 
 
 
81de315
4557487
b3a728f
4557487
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# /home/perk/mymodel/categorisation-mt5x/tasks.py

import functools
import seqio
import my_metrics
import tensorflow_datasets as tfds
from t5.evaluation import metrics
from t5.data import preprocessors
import t5
import tensorflow.compat.v1 as tf

# Root of the public GCS bucket holding the fine-tuning datasets.
FINETUNE_DATASETS_ROOT = "gs://notram-public/finetune_datasets"


def tsv_split_paths(dataset_dir, root=FINETUNE_DATASETS_ROOT):
  """Return the seqio split -> TSV file mapping for one fine-tuning dataset.

  Every dataset directory uses the same layout; note that the seqio
  "validation" split is read from ``dev.tsv``.

  Args:
    dataset_dir: name of the dataset directory under ``root``.
    root: bucket prefix the dataset directory lives in.

  Returns:
    A dict with keys "train", "validation" and "test" mapping to file paths.
  """
  base = "/".join((root, dataset_dir))
  return {
      "train": base + "/train.tsv",
      "validation": base + "/dev.tsv",
      "test": base + "/test.tsv",
  }


tsv_parliament_path = tsv_split_paths(
    "parliament_speeches_1998_2016_frp_or_sv")

tsv_sentiment_path = tsv_split_paths("norec_sentiment")

# Shared SentencePiece vocabulary for both features (presumably the mT5 mc4
# vocabulary, judging by the path). extra_ids=0: the file name suggests the
# 100 sentinel ids are already inside the model, so none are added on top --
# confirm if the vocabulary is ever swapped.
vocabulary = seqio.SentencePieceVocabulary(
    "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
    extra_ids=0,
)

# "inputs" and "targets" are configured identically: same vocabulary, and an
# EOS token is appended to each.
DEFAULT_OUTPUT_FEATURES = {
    feature_name: seqio.Feature(vocabulary=vocabulary, add_eos=True)
    for feature_name in ("inputs", "targets")
}

def categorise_preprocessor(ds):
  """Convert parsed TSV rows into seqio "inputs"/"targets" text features.

  Args:
    ds: a tf.data.Dataset of dicts holding string tensors under "source" and
      "target" (presumably produced by t5's parse_tsv with
      field_names=["target", "source"] -- confirm against the task config).

  Returns:
    A tf.data.Dataset of {"inputs": <normalized source>,
    "targets": <normalized target>} dicts; any other keys of the input rows
    are dropped.
  """
  def normalize_text(text):
    """Remove the first and last single quote from a TensorFlow string.

    No lowercasing is performed. The pattern is greedy, so the match spans
    from the first "'" to the last "'" in the string and only those two
    characters are removed; quotes in between are kept. Strings containing
    fewer than two quotes are returned unchanged.
    """
    # NOTE(review): an apostrophe inside the text is also removed whenever it
    # is the first or last quote in the string. If the intent is to strip only
    # enclosing quotes, anchor the pattern ("^'(.*)'$") -- but that changes the
    # training/eval data, so confirm before switching.
    return tf.strings.regex_replace(text, "'(.*)'", r"\1")

  def to_inputs_and_targets(ex):
    """Map {"target": ..., "source": ...} -> {"inputs": ..., "targets": ...}."""
    return {
        "inputs": normalize_text(ex["source"]),
        "targets": normalize_text(ex["target"]),
    }

  return ds.map(to_inputs_and_targets,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)


def _register_categorisation_task(name, split_to_filepattern):
  """Register a seqio text-classification task backed by TSV files.

  Each TSV line is parsed into the fields "target" and "source" (in that
  column order), passed through categorise_preprocessor, tokenized with EOS
  appended, and evaluated with accuracy and macro-F1.

  Args:
    name: seqio task name to register.
    split_to_filepattern: dict mapping split names to TSV file patterns.
  """
  seqio.TaskRegistry.add(
      name,
      source=seqio.TextLineDataSource(
          split_to_filepattern=split_to_filepattern),
      preprocessors=[
          functools.partial(
              t5.data.preprocessors.parse_tsv,
              field_names=["target", "source"]),
          categorise_preprocessor,
          seqio.preprocessors.tokenize_and_append_eos,
      ],
      metric_fns=[metrics.accuracy, my_metrics.f1_macro],
      output_features=DEFAULT_OUTPUT_FEATURES,
  )


# Both tasks share the exact same pipeline; only the data location differs.
_register_categorisation_task("parliament", tsv_parliament_path)
_register_categorisation_task("sentiment", tsv_sentiment_path)