pere committed on
Commit ad029a3
1 Parent(s): c1c88b6

first deuncaser submit

Files changed (4)
  1. README.md +1 -1
  2. finetune_deuncaser_base.gin +39 -0
  3. my_metrics.py +7 -0
  4. tasks.py +207 -0
README.md CHANGED
@@ -1 +1 @@
- Placeholder for a model to come...
+ Private sample code for running categorisation on the mT5X
finetune_deuncaser_base.gin ADDED
@@ -0,0 +1,39 @@
+ from __gin__ import dynamic_registration
+ import tasks
+
+ import __main__ as train_script
+ from t5.data import mixtures
+ from t5x import models
+ from t5x import partitioning
+ from t5x import utils
+
+ include "t5x/examples/t5/mt5/base.gin"
+ include "t5x/configs/runs/finetune.gin"
+
+ MIXTURE_OR_TASK_NAME = %gin.REQUIRED
+ TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 64}
+ INITIAL_CHECKPOINT_PATH = %gin.REQUIRED
+ TRAIN_STEPS = %gin.REQUIRED  # 1000000 pre-trained steps + 10000 fine-tuning steps.
+ USE_CACHED_TASKS = False
+ DROPOUT_RATE = 0.1
+ RANDOM_SEED = 0
+
+ # Fix a small error: give infer_eval the same task_feature_lengths
+ infer_eval/utils.DatasetConfig:
+   task_feature_lengths = %TASK_FEATURE_LENGTHS
+
+ # Save a checkpoint every 1000 steps
+ utils.SaveCheckpointConfig:
+   period = 1000
+
+
+ # Pere: Only necessary if we load a T5 model. We can start with a t5x model here.
+ # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained
+ # using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be
+ # set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1:
+ # `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`.
+ # LOSS_NORMALIZING_FACTOR = 234496
+
+ # Might have to be changed based on architecture
+ # partitioning.PjitPartitioner.num_partitions = 1
+
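The three %gin.REQUIRED values above have to be bound when the run is launched. A minimal sketch of those bindings, assuming the classify_tweets task registered in tasks.py below; the checkpoint path is a hypothetical mT5-base location, not something this commit pins down:

    MIXTURE_OR_TASK_NAME = "classify_tweets"
    INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_base/checkpoint_1000000"  # hypothetical path
    TRAIN_STEPS = 1010000  # 1000000 pre-trained + 10000 fine-tuning steps

If a Mesh Tensorflow mT5 checkpoint were loaded instead, the commented-out LOSS_NORMALIZING_FACTOR above would apply: 1024 * 229 = 234496.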
my_metrics.py ADDED
@@ -0,0 +1,7 @@
+ import sklearn.metrics
+ import numpy as np
+
+ def f1_macro(targets, predictions):
+     # Cast to string arrays so gold labels and decoded predictions compare consistently.
+     targets, predictions = np.asarray(targets).astype(str), np.asarray(predictions).astype(str)
+     # Macro-averaged F1, scaled to a percentage.
+     return {"f1_macro": 100 * sklearn.metrics.f1_score(targets, predictions, average='macro')}
+
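As a quick sanity check (not part of the commit), f1_macro can be exercised directly. The toy labels below are made up purely to show the dict format seqio expects from a metric_fn:

    import my_metrics

    targets = ["0", "1", "2", "2"]       # hypothetical gold labels
    predictions = ["0", "1", "1", "2"]   # hypothetical decoded predictions
    print(my_metrics.f1_macro(targets, predictions))
    # -> {'f1_macro': 77.77...}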
tasks.py ADDED
@@ -0,0 +1,207 @@
+ # /home/perk/mymodel/categorisation-mt5x/tasks.py
+
+ import functools
+ import seqio
+ import tensorflow_datasets as tfds
+ from t5.evaluation import metrics
+ import my_metrics
+ from t5.data import preprocessors
+ import t5
+ import tensorflow.compat.v1 as tf
+
+ tsv_path = {
+     "train": "gs://eu-jav-t5x/corpus/labeled/datasetA_train_3categories.tsv",
+     "validation": "gs://eu-jav-t5x/corpus/labeled/datasetA_dev_3categories.tsv",
+     "test": "gs://eu-jav-t5x/corpus/labeled/datasetA_test_3categories.tsv"
+ }
+
+ vocabulary = seqio.SentencePieceVocabulary(
+     'gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model', extra_ids=0)
+
+ DEFAULT_OUTPUT_FEATURES = {
+     "inputs": seqio.Feature(vocabulary=vocabulary, add_eos=True),
+     "targets": seqio.Feature(vocabulary=vocabulary, add_eos=True)
+ }
+
+ def categorise_preprocessor(ds):
+     def normalize_text(text):
+         """Strip surrounding single quotes from a TensorFlow string."""
+         text = tf.strings.regex_replace(text, "'(.*)'", r"\1")
+         return text
+
+     def to_inputs_and_targets(ex):
+         """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
+         return {
+             "inputs": normalize_text(ex["source"]),
+             "targets": normalize_text(ex["target"]),
+         }
+
+     return ds.map(to_inputs_and_targets,
+                   num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+
+ def categorise_fulltext_preprocessor(ds):
+     def normalize_text(text):
+         """Strip surrounding single quotes from a TensorFlow string."""
+         text = tf.strings.regex_replace(text, "'(.*)'", r"\1")
+         return text
+
+     def fulltext(t):
+         # Expand the numeric label into an Italian sentence
+         # (0 = favourable to vaccination, 1 = neutral, 2 = unfavourable).
+         # tf.where keeps the branching graph-compatible inside ds.map.
+         t = tf.where(tf.equal(t, "0"), "il testo è favorevole alla vaccinazione", t)
+         t = tf.where(tf.equal(t, "1"), "il testo è neutro rispetto alla vaccinazione", t)
+         t = tf.where(tf.equal(t, "2"), "il testo è sfavorevole alla vaccinazione", t)
+         return t
+
+     def to_inputs_and_targets(ex):
+         """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
+         return {
+             "inputs": normalize_text(ex["source"]),
+             "targets": fulltext(normalize_text(ex["target"])),
+         }
+
+     return ds.map(to_inputs_and_targets,
+                   num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+
+ def categorise_fulltext_word_preprocessor(ds):
+     def normalize_text(text):
+         """Strip surrounding single quotes from a TensorFlow string."""
+         text = tf.strings.regex_replace(text, "'(.*)'", r"\1")
+         return text
+
+     def fulltext(t):
+         # Expand the numeric label into a single Italian word
+         # (0 = promotional, 1 = neutral, 2 = discouraging).
+         t = tf.where(tf.equal(t, "0"), "promozionale", t)
+         t = tf.where(tf.equal(t, "1"), "neutro", t)
+         t = tf.where(tf.equal(t, "2"), "scoraggiante", t)
+         return t
+
+     def to_inputs_and_targets(ex):
+         """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
+         return {
+             "inputs": normalize_text(ex["source"]),
+             "targets": fulltext(normalize_text(ex["target"])),
+         }
+
+     return ds.map(to_inputs_and_targets,
+                   num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+
+ def categorise_binary_preprocessor(ds):
+     def normalize_text(text):
+         """Strip surrounding single quotes from a TensorFlow string."""
+         text = tf.strings.regex_replace(text, "'(.*)'", r"\1")
+         return text
+
+     def fulltext(t):
+         # Collapse the three classes into two: favourable (0) and neutral (1)
+         # both become "1"; unfavourable (2) keeps the label "2".
+         t = tf.where(tf.equal(t, "0"), "1", t)
+         return t
+
+     def to_inputs_and_targets(ex):
+         """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
+         return {
+             "inputs": normalize_text(ex["source"]),
+             "targets": fulltext(normalize_text(ex["target"])),
+         }
+
+     return ds.map(to_inputs_and_targets,
+                   num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+
+ seqio.TaskRegistry.add(
+     "classify_tweets",
+     source=seqio.TextLineDataSource(
+         split_to_filepattern=tsv_path,
+         # num_input_examples=num_nq_examples
+     ),
+     preprocessors=[
+         functools.partial(
+             t5.data.preprocessors.parse_tsv,
+             field_names=["annotator1", "annotator2", "annotator3", "target", "source", "id"]),
+         categorise_preprocessor,
+         seqio.preprocessors.tokenize_and_append_eos,
+     ],
+     metric_fns=[metrics.accuracy, my_metrics.f1_macro],
+     output_features=DEFAULT_OUTPUT_FEATURES,
+ )
+
+ seqio.TaskRegistry.add(
+     "classify_tweets_fulltext",
+     source=seqio.TextLineDataSource(
+         split_to_filepattern=tsv_path,
+         # num_input_examples=num_nq_examples
+     ),
+     preprocessors=[
+         functools.partial(
+             t5.data.preprocessors.parse_tsv,
+             field_names=["annotator1", "annotator2", "annotator3", "target", "source", "id"]),
+         categorise_fulltext_preprocessor,
+         seqio.preprocessors.tokenize_and_append_eos,
+     ],
+     metric_fns=[metrics.accuracy, my_metrics.f1_macro],
+     output_features=DEFAULT_OUTPUT_FEATURES,
+ )
+
+ seqio.TaskRegistry.add(
+     "classify_tweets_binary",
+     source=seqio.TextLineDataSource(
+         split_to_filepattern=tsv_path,
+         # num_input_examples=num_nq_examples
+     ),
+     preprocessors=[
+         functools.partial(
+             t5.data.preprocessors.parse_tsv,
+             field_names=["annotator1", "annotator2", "annotator3", "target", "source", "id"]),
+         categorise_binary_preprocessor,
+         seqio.preprocessors.tokenize_and_append_eos,
+     ],
+     metric_fns=[metrics.accuracy, my_metrics.f1_macro],
+     output_features=DEFAULT_OUTPUT_FEATURES,
+ )
+
+ seqio.TaskRegistry.add(
+     "classify_tweets_fulltext_word",
+     source=seqio.TextLineDataSource(
+         split_to_filepattern=tsv_path,
+         # num_input_examples=num_nq_examples
+     ),
+     preprocessors=[
+         functools.partial(
+             t5.data.preprocessors.parse_tsv,
+             field_names=["annotator1", "annotator2", "annotator3", "target", "source", "id"]),
+         categorise_fulltext_word_preprocessor,
+         seqio.preprocessors.tokenize_and_append_eos,
+     ],
+     metric_fns=[metrics.accuracy, my_metrics.f1_macro],
+     output_features=DEFAULT_OUTPUT_FEATURES,
+ )
+
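For context, a minimal sketch (not part of the commit) of loading one of the registered tasks for inspection. It assumes read access to the private gs://eu-jav-t5x bucket; the row layout in the comment is inferred from field_names, with hypothetical values:

    # Row layout implied by field_names above (values hypothetical):
    # annotator1<TAB>annotator2<TAB>annotator3<TAB>target<TAB>source<TAB>id

    import seqio
    import tasks  # importing registers the tasks as a side effect

    task = seqio.get_mixture_or_task("classify_tweets")
    ds = task.get_dataset(
        sequence_length={"inputs": 256, "targets": 64},  # mirrors TASK_FEATURE_LENGTHS
        split="validation",
    )
    for ex in ds.take(1):
        print(ex["inputs"])   # tokenized, quote-stripped tweet text
        print(ex["targets"])  # tokenized category label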