pere commited on
Commit
a46d466
1 Parent(s): a8858d7
Files changed (2) hide show
  1. batch_finetune_eu_jav_large_binary.sh +11 -0
  2. tasks.py +21 -3
batch_finetune_eu_jav_large_binary.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PROJECT_DIR=${HOME}"/models/eu-jav-categorisation"
2
+ export PYTHONPATH=${PROJECT_DIR}
3
+ INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/mt5_large/checkpoint_1000000\"
4
+ TRAIN_STEPS=1002000
5
+
6
+ python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_large.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_binary\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/classify_tweets_large_binary_v1\" &&
7
+ python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_large.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_binary\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/classify_tweets_largebinary__v2\" &&
8
+ python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_large.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_binary\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/classify_tweets_large_binary_v3\" &&
9
+ python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_large.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_binary\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/classify_tweets_large_binary_v4\" &&
10
+ python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_large.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_binary\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/classify_tweets_large_binary_v5\"
11
+
tasks.py CHANGED
@@ -115,11 +115,11 @@ def categorise_binary_preprocessor(ds):
115
 
116
  def fulltext(t):
117
  if t=="0":
118
- t="1"
119
  elif t=="1":
120
- t="1"
121
  elif t=="2":
122
- t="2"
123
  return t
124
 
125
  def to_inputs_and_targets(ex):
@@ -137,6 +137,24 @@ def categorise_binary_preprocessor(ds):
137
 
138
 
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  seqio.TaskRegistry.add(
141
  "classify_tweets",
142
  source=seqio.TextLineDataSource(
 
115
 
116
  def fulltext(t):
117
  if t=="0":
118
+ t="0"
119
  elif t=="1":
120
+ t="0"
121
  elif t=="2":
122
+ t="1"
123
  return t
124
 
125
  def to_inputs_and_targets(ex):
 
137
 
138
 
139
 
140
+ seqio.TaskRegistry.add(
141
+ "classify_tweets_binary",
142
+ source=seqio.TextLineDataSource(
143
+ split_to_filepattern=tsv_path,
144
+ #num_input_examples=num_nq_examples
145
+ ),
146
+ preprocessors=[
147
+ functools.partial(
148
+ t5.data.preprocessors.parse_tsv,
149
+ field_names=["annotator1","annotator2","annotator3","target","source","id"]),
150
+ categorise_binary_preprocessor,
151
+ seqio.preprocessors.tokenize_and_append_eos,
152
+ ],
153
+ metric_fns=[metrics.accuracy,my_metrics.f1_macro],
154
+ output_features=DEFAULT_OUTPUT_FEATURES,
155
+ )
156
+
157
+
158
  seqio.TaskRegistry.add(
159
  "classify_tweets",
160
  source=seqio.TextLineDataSource(