pere commited on
Commit
041a858
1 Parent(s): e73c2c5
batch_finetune_eu_jav_base_exp_fulltext_word.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PROJECT_DIR=${HOME}"/models/eu-jav-categorisation"
2
+ export PYTHONPATH=${PROJECT_DIR}
3
+ INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/mt5_base/checkpoint_1000000\"
4
+ TRAIN_STEPS=1005000
5
+
6
+ python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_fulltext_word\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/fulltext_words_classify_tweets_fulltext_words_base_v1\" &&
7
+ python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_fulltext_word\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/fulltext_words_classify_tweets_fulltext_words_base_v2\" &&
8
+ python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_fulltext_word\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/fulltext_words_classify_tweets_fulltext_words_base_v3\" &&
9
+ python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_fulltext_word\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/fulltext_words_classify_tweets_fulltext_words_base_v4\" &&
10
+ python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_fulltext_word\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/fulltext_words_classify_tweets_fulltext_words_base_v5\"
11
+
tasks.py CHANGED
@@ -59,7 +59,7 @@ def categorise_fulltext_preprocessor(ds):
59
  t="il testo è favorevole alla vaccinazione"
60
  elif t=="1":
61
  t="il testo è neutro rispetto alla vaccinazione"
62
- elif t=="3":
63
  t="is testo è sfavorevole alla vaccinazione"
64
  return t
65
 
@@ -77,6 +77,35 @@ def categorise_fulltext_preprocessor(ds):
77
  num_parallel_calls=tf.data.experimental.AUTOTUNE)
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  seqio.TaskRegistry.add(
81
  "classify_tweets",
82
  source=seqio.TextLineDataSource(
@@ -111,3 +140,20 @@ seqio.TaskRegistry.add(
111
  output_features=DEFAULT_OUTPUT_FEATURES,
112
  )
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  t="il testo è favorevole alla vaccinazione"
60
  elif t=="1":
61
  t="il testo è neutro rispetto alla vaccinazione"
62
+ elif t=="2":
63
  t="is testo è sfavorevole alla vaccinazione"
64
  return t
65
 
 
77
  num_parallel_calls=tf.data.experimental.AUTOTUNE)
78
 
79
 
80
+ def categorise_fulltext_word_preprocessor(ds):
81
+ def normalize_text(text):
82
+ """Lowercase and remove quotes from a TensorFlow string."""
83
+ text = tf.strings.regex_replace(text,"'(.*)'", r"\1")
84
+ return text
85
+
86
+ def fulltext(t):
87
+ if t=="0":
88
+ t="promozionale"
89
+ elif t=="1":
90
+ t="neutro"
91
+ elif t=="2":
92
+ t="scoraggiante"
93
+ return t
94
+
95
+ def to_inputs_and_targets(ex):
96
+ """Map {"source": ..., "source": ...}->{"target": ..., "target": ...}."""
97
+ return {
98
+ "inputs":
99
+ tf.strings.join(
100
+ [normalize_text(ex["source"])]),
101
+ "targets":
102
+ tf.strings.join(
103
+ [fulltext(normalize_text(ex["target"]))]),
104
+ }
105
+ return ds.map(to_inputs_and_targets,
106
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
107
+
108
+
109
  seqio.TaskRegistry.add(
110
  "classify_tweets",
111
  source=seqio.TextLineDataSource(
 
140
  output_features=DEFAULT_OUTPUT_FEATURES,
141
  )
142
 
143
+ seqio.TaskRegistry.add(
144
+ "classify_tweets_fulltext_word",
145
+ source=seqio.TextLineDataSource(
146
+ split_to_filepattern=tsv_path,
147
+ #num_input_examples=num_nq_examples
148
+ ),
149
+ preprocessors=[
150
+ functools.partial(
151
+ t5.data.preprocessors.parse_tsv,
152
+ field_names=["annotator1","annotator2","annotator3","target","source","id"]),
153
+ categorise_fulltext_word_preprocessor,
154
+ seqio.preprocessors.tokenize_and_append_eos,
155
+ ],
156
+ metric_fns=[metrics.accuracy,my_metrics.f1_macro],
157
+ output_features=DEFAULT_OUTPUT_FEATURES,
158
+ )
159
+