pere committed
Commit 5e4d91f
1 Parent(s): ba171ad

first commit

Files changed (4)
  1. finetune_base.sh +10 -0
  2. finetune_translate_base.gin +39 -0
  3. my_metrics.py +7 -0
  4. tasks.py +183 -0
finetune_base.sh ADDED
@@ -0,0 +1,10 @@
+ PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
+ export PYTHONPATH=${PROJECT_DIR}
+
+ python3 ../../t5x/t5x/train.py \
+   --gin_search_paths="./" \
+   --gin_file="finetune_translate_base.gin" \
+   --gin.MIXTURE_OR_TASK_NAME=\"translate\" \
+   --gin.MODEL_DIR=\"gs://nb-t5x-us-central2/finetuned/v1_nynorsk_translate_mT5_base\"
finetune_translate_base.gin ADDED
@@ -0,0 +1,39 @@
+ from __gin__ import dynamic_registration
+ import tasks
+
+ import __main__ as train_script
+ from t5.data import mixtures
+ from t5x import models
+ from t5x import partitioning
+ from t5x import utils
+
+ include "t5x/examples/t5/mt5/base.gin"
+ include "t5x/configs/runs/finetune.gin"
+
+ MIXTURE_OR_TASK_NAME = %gin.REQUIRED
+ TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
+ INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_base/checkpoint_1000000"
+ TRAIN_STEPS = 1_010_000  # 1,000,000 pre-trained steps + 10,000 fine-tuning steps.
+ USE_CACHED_TASKS = False
+ DROPOUT_RATE = 0.1
+ RANDOM_SEED = 0
+
+ # Fix a small error: give the infer_eval dataset explicit task feature lengths.
+ infer_eval/utils.DatasetConfig:
+   task_feature_lengths = %TASK_FEATURE_LENGTHS
+
+ # Save a checkpoint every 1000 steps.
+ utils.SaveCheckpointConfig:
+   period = 1000
+
+
+ # Pere: Only necessary if we load a Mesh TensorFlow T5 model; we can start with a t5x model here.
+ # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained
+ # using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be
+ # set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1:
+ # `2048 * 114`. For mT5: `1024 * 229` = 234496. For ByT5: `1024 * 189`.
+ # LOSS_NORMALIZING_FACTOR = 234496
+
+ # Might have to be changed based on the architecture.
+ # partitioning.PjitPartitioner.num_partitions = 1
my_metrics.py ADDED
@@ -0,0 +1,7 @@
+ import sklearn.metrics
+ import numpy as np
+
+ def f1_macro(targets, predictions):
+     """Macro-averaged F1 (in percent) over string-valued labels."""
+     targets = np.asarray(targets).astype(str)
+     predictions = np.asarray(predictions).astype(str)
+     return {"f1_macro": 100 * sklearn.metrics.f1_score(targets, predictions, average='macro')}
+
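
A quick sanity check of what this metric_fn returns (the label strings below are made-up examples; after an eval run, seqio passes lists of decoded target and prediction strings to each metric_fn):

    import my_metrics

    targets = ["positive", "negative", "positive"]
    predictions = ["positive", "positive", "positive"]
    print(my_metrics.f1_macro(targets, predictions))
    # {'f1_macro': 40.0} -- per-class F1 is 0.8 and 0.0, macro-averaged to 0.4, times 100.
    # sklearn also warns here, since "negative" is never predicted.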
tasks.py ADDED
@@ -0,0 +1,183 @@
+ # /home/perk/mymodel/categorisation-mt5x/tasks.py
+
+ import functools
+ import seqio
+ import my_metrics
+ import tensorflow_datasets as tfds
+ from t5.evaluation import metrics
+ from t5.data import preprocessors
+ #import my_preprocessors
+ import t5
+ import tensorflow.compat.v1 as tf
+
+
+ tsv_parliament_path = {
+     "train": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv/train.tsv",
+     "validation": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv/dev.tsv",
+     "test": "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv/test.tsv"
+ }
+
+ tsv_translate_path = {
+     "train": "gs://nb-t5x-us-central2/corpus_bokmal_nynorsk/train.tsv",
+     "validation": "gs://nb-t5x-us-central2/corpus_bokmal_nynorsk/dev.tsv",
+     "test": "gs://nb-t5x-us-central2/corpus_bokmal_nynorsk/test.tsv"
+ }
+
+ tsv_sentiment_path = {
+     "train": "gs://notram-public/finetune_datasets/norec_sentiment/train.tsv",
+     "validation": "gs://notram-public/finetune_datasets/norec_sentiment/dev.tsv",
+     "test": "gs://notram-public/finetune_datasets/norec_sentiment/test.tsv"
+ }
+
+ json_angry_tweets_path = {
+     "train": "gs://notram-public/finetune_datasets/angry_tweets/train.jsonl",
+     "validation": "gs://notram-public/finetune_datasets/angry_tweets/test.jsonl",
+     "test": "gs://notram-public/finetune_datasets/angry_tweets/test.jsonl"
+ }
+
+ tsv_angry_tweets_path = {
+     "train": "gs://notram-public/finetune_datasets/angry_tweets/train.tsv",
+     "validation": "gs://notram-public/finetune_datasets/angry_tweets/test.tsv",
+     "test": "gs://notram-public/finetune_datasets/angry_tweets/test.tsv"
+ }
+
+ tsv_dane_path = {
+     "train": "gs://notram-public/finetune_datasets/dane/train.tsv",
+     "validation": "gs://notram-public/finetune_datasets/dane/test.tsv",
+     "test": "gs://notram-public/finetune_datasets/dane/test.tsv"
+ }
+
+ tsv_dane_tokens_path = {
+     "train": "gs://notram-public/finetune_datasets/dane/train_tokens.tsv",
+     "validation": "gs://notram-public/finetune_datasets/dane/test_tokens.tsv",
+     "test": "gs://notram-public/finetune_datasets/dane/test_tokens.tsv"
+ }
+
+ tsv_dane_long_tokens_path = {
+     "train": "gs://notram-public/finetune_datasets/dane/train_long_tokens.tsv",
+     "validation": "gs://notram-public/finetune_datasets/dane/test_long_tokens.tsv",
+     "test": "gs://notram-public/finetune_datasets/dane/test_long_tokens.tsv"
+ }
+
+
+ vocabulary = seqio.SentencePieceVocabulary(
+     'gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model', extra_ids=0)
+
+ DEFAULT_OUTPUT_FEATURES = {
+     "inputs": seqio.Feature(vocabulary=vocabulary, add_eos=True),
+     "targets": seqio.Feature(vocabulary=vocabulary, add_eos=True)
+ }
+
+
+ def categorise_preprocessor(ds):
+     """Map parsed TSV rows to the "inputs"/"targets" features seqio expects."""
+     def normalize_text(text):
+         """Normalization placeholder; the quote-stripping regex is disabled,
+         so the text is currently returned unchanged."""
+         #text = tf.strings.regex_replace(text, "'(.*)'", r"\1")
+         return text
+
+     def to_inputs_and_targets(ex):
+         """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
+         return {
+             "inputs": normalize_text(ex["source"]),
+             "targets": normalize_text(ex["target"]),
+         }
+
+     return ds.map(to_inputs_and_targets,
+                   num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+
+ seqio.TaskRegistry.add(
+     "parliament",
+     source=seqio.TextLineDataSource(
+         split_to_filepattern=tsv_parliament_path,
+         #num_input_examples=num_nq_examples
+     ),
+     preprocessors=[
+         functools.partial(
+             t5.data.preprocessors.parse_tsv,
+             field_names=["target", "source"]),
+         categorise_preprocessor,
+         seqio.preprocessors.tokenize_and_append_eos,
+     ],
+     metric_fns=[metrics.accuracy, my_metrics.f1_macro],
+     output_features=DEFAULT_OUTPUT_FEATURES,
+ )
+
+ seqio.TaskRegistry.add(
+     "sentiment",
+     source=seqio.TextLineDataSource(
+         split_to_filepattern=tsv_sentiment_path,
+         #num_input_examples=num_nq_examples
+     ),
+     preprocessors=[
+         functools.partial(
+             t5.data.preprocessors.parse_tsv,
+             field_names=["target", "source"]),
+         categorise_preprocessor,
+         seqio.preprocessors.tokenize_and_append_eos,
+     ],
+     metric_fns=[metrics.accuracy, my_metrics.f1_macro],
+     output_features=DEFAULT_OUTPUT_FEATURES,
+ )
+
+ seqio.TaskRegistry.add(
+     "angry_tweets",
+     source=seqio.TextLineDataSource(
+         split_to_filepattern=tsv_angry_tweets_path,
+         #num_input_examples=num_nq_examples
+     ),
+     preprocessors=[
+         functools.partial(
+             t5.data.preprocessors.parse_tsv,
+             field_names=["target", "source"]),
+         categorise_preprocessor,
+         seqio.preprocessors.tokenize_and_append_eos,
+     ],
+     metric_fns=[metrics.accuracy, my_metrics.f1_macro],
+     output_features=DEFAULT_OUTPUT_FEATURES,
+ )
+
+ seqio.TaskRegistry.add(
+     "dane",
+     source=seqio.TextLineDataSource(
+         split_to_filepattern=tsv_dane_long_tokens_path,
+         #num_input_examples=num_nq_examples
+     ),
+     preprocessors=[
+         functools.partial(
+             t5.data.preprocessors.parse_tsv,
+             field_names=["placeholder1", "placeholder2", "placeholder3", "target", "source"]),
+         categorise_preprocessor,
+         seqio.preprocessors.tokenize_and_append_eos,
+     ],
+     metric_fns=[metrics.accuracy, my_metrics.f1_macro],
+     output_features=DEFAULT_OUTPUT_FEATURES,
+ )
+
+ seqio.TaskRegistry.add(
+     "translate",
+     source=seqio.TextLineDataSource(
+         split_to_filepattern=tsv_translate_path,
+         #num_input_examples=num_nq_examples
+     ),
+     preprocessors=[
+         functools.partial(
+             t5.data.preprocessors.parse_tsv,
+             field_names=["source", "target"]),
+         categorise_preprocessor,
+         seqio.preprocessors.tokenize_and_append_eos,
+     ],
+     metric_fns=[metrics.accuracy, my_metrics.f1_macro],
+     output_features=DEFAULT_OUTPUT_FEATURES,
+ )
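
Once tasks.py is importable (finetune_base.sh puts PROJECT_DIR on PYTHONPATH), a registered task can be sanity-checked outside of training. A minimal sketch, assuming the gs:// buckets above are readable from the current environment:

    import seqio
    import tasks  # importing has the side effect of registering the tasks above

    task = seqio.get_mixture_or_task("translate")
    ds = task.get_dataset(
        sequence_length={"inputs": 512, "targets": 512},
        split="validation",
        shuffle=False,
    )
    for ex in ds.take(1):
        # Tokenized features produced by tokenize_and_append_eos.
        print(ex["inputs"], ex["targets"])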