File size: 1,878 Bytes
05cd399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# /home/perk/mymodel/categorisation-mt5x/tasks.py

import functools
import seqio
import tensorflow_datasets as tfds
from t5.evaluation import metrics
from t5.data import preprocessors
import t5
import tensorflow.compat.v1 as tf

# Location of the tab-separated data file for each dataset split.
tsv_path = {
    "train": "gs://peregilcloud/italian_tweets/train.tsv",
    "validation": "gs://peregilcloud/italian_tweets/dev.tsv",
    "test": "gs://peregilcloud/italian_tweets/test.tsv",
}

# SentencePiece vocabulary shared by the input and target features.
# extra_ids=0: no extra ids are requested on top of the loaded model.
vocabulary = seqio.SentencePieceVocabulary(
    'gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model',
    extra_ids=0,
)

# Both features use the same vocabulary and are configured with add_eos=True;
# each key gets its own seqio.Feature instance.
DEFAULT_OUTPUT_FEATURES = {
    feature_name: seqio.Feature(vocabulary=vocabulary, add_eos=True)
    for feature_name in ("inputs", "targets")
}

def categorise_preprocessor(ds):
  """Convert parsed rows into the "inputs"/"targets" text features.

  Args:
    ds: a dataset whose elements are dicts providing string tensors under
      the keys "source" and "target".

  Returns:
    The dataset produced by mapping every element to a dict with the keys
    "inputs" (built from "source") and "targets" (built from "target").
  """
  # Output feature name -> key of the incoming example it is built from.
  feature_to_column = (("inputs", "source"), ("targets", "target"))

  def strip_quotes(text):
    """Replace every match of the pattern '(.*)' with its captured group.

    In effect, text enclosed in single quotes loses the enclosing quotes.
    No case conversion is performed.
    """
    return tf.strings.regex_replace(text, "'(.*)'", r"\1")

  def build_example(row):
    """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
    return {
        feature: tf.strings.join([strip_quotes(row[column])])
        for feature, column in feature_to_column
    }

  return ds.map(
      build_example,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
  )


# Register the "categorise" task: text lines are read from the TSV files in
# tsv_path, parsed into fields, converted to inputs/targets and tokenized.
seqio.TaskRegistry.add(
    "categorise",
    # The number of input examples is not declared for this source.
    source=seqio.TextLineDataSource(split_to_filepattern=tsv_path),
    output_features=DEFAULT_OUTPUT_FEATURES,
    preprocessors=[
        # Each line is parsed into the named fields "target" and "source",
        # in that order.
        functools.partial(
            preprocessors.parse_tsv,
            field_names=["target", "source"],
        ),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    # No metric_fns are registered for this task (metrics.bleu was previously
    # considered but is not enabled).
)