pere committed on
Commit
91bc225
1 Parent(s): fabe95c
batch_finetune_eu_jav_base_exp_binary.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PROJECT_DIR=${HOME}"/models/eu-jav-categorisation"
2
+ export PYTHONPATH=${PROJECT_DIR}
3
+ INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/mt5_base/checkpoint_1000000\"
4
+ TRAIN_STEPS=1005000
5
+
6
+ python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_binary\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/fulltext_words_classify_tweets_binary_base_v1\" &&
7
+ python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_binary\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/fulltext_words_classify_tweets_binary_base_v2\" &&
8
+ python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_binary\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/fulltext_words_classify_tweets_binary_base_v3\" &&
9
+ python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_binary\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/fulltext_words_classify_tweets_binary_base_v4\" &&
10
+ python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_binary\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/fulltext_words_classify_tweets_binary_base_v5\"
11
+
tasks.py CHANGED
@@ -106,6 +106,37 @@ def categorise_fulltext_word_preprocessor(ds):
106
  num_parallel_calls=tf.data.experimental.AUTOTUNE)
107
 
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  seqio.TaskRegistry.add(
110
  "classify_tweets",
111
  source=seqio.TextLineDataSource(
@@ -140,6 +171,23 @@ seqio.TaskRegistry.add(
140
  output_features=DEFAULT_OUTPUT_FEATURES,
141
  )
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  seqio.TaskRegistry.add(
144
  "classify_tweets_fulltext_word",
145
  source=seqio.TextLineDataSource(
 
106
  num_parallel_calls=tf.data.experimental.AUTOTUNE)
107
 
108
 
109
+
110
+ def categorise_binary_preprocessor(ds):
111
+ def normalize_text(text):
112
+ """Strip an enclosing pair of single quotes from a TensorFlow string."""
113
+ text = tf.strings.regex_replace(text,"'(.*)'", r"\1")
114
+ return text
115
+
116
+ def fulltext(t):
117
+ if t=="0":
118
+ t="1"
119
+ elif t=="1":
120
+ t="1"
121
+ elif t=="2":
122
+ t="2"
123
+ return t
124
+
125
+ def to_inputs_and_targets(ex):
126
+ """Map {"source": ..., "target": ...}->{"inputs": ..., "targets": ...}."""
127
+ return {
128
+ "inputs":
129
+ tf.strings.join(
130
+ [normalize_text(ex["source"])]),
131
+ "targets":
132
+ tf.strings.join(
133
+ [fulltext(normalize_text(ex["target"]))]),
134
+ }
135
+ return ds.map(to_inputs_and_targets,
136
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
137
+
138
+
139
+
140
  seqio.TaskRegistry.add(
141
  "classify_tweets",
142
  source=seqio.TextLineDataSource(
 
171
  output_features=DEFAULT_OUTPUT_FEATURES,
172
  )
173
 
174
+ seqio.TaskRegistry.add(
175
+ "classify_tweets_binary",
176
+ source=seqio.TextLineDataSource(
177
+ split_to_filepattern=tsv_path,
178
+ #num_input_examples=num_nq_examples
179
+ ),
180
+ preprocessors=[
181
+ functools.partial(
182
+ t5.data.preprocessors.parse_tsv,
183
+ field_names=["annotator1","annotator2","annotator3","target","source","id"]),
184
+ categorise_fulltext_word_preprocessor,
185
+ seqio.preprocessors.tokenize_and_append_eos,
186
+ ],
187
+ metric_fns=[metrics.accuracy,my_metrics.f1_macro],
188
+ output_features=DEFAULT_OUTPUT_FEATURES,
189
+ )
190
+
191
  seqio.TaskRegistry.add(
192
  "classify_tweets_fulltext_word",
193
  source=seqio.TextLineDataSource(