pere committed
Commit 4d857d2
1 Parent(s): dea5f5f
batch_finetune_eu_jav_base_exp_english.sh CHANGED
@@ -1,6 +1,6 @@
 PROJECT_DIR=${HOME}"/models/eu-jav-categorisation"
 export PYTHONPATH=${PROJECT_DIR}
-INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_basecheckpoint_1000000\"
+INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_base/checkpoint_1000000\"
 TRAIN_STEPS=1005000
 
 python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base_exp_english.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/english_classify_tweets_base_v1\" &&
batch_finetune_eu_jav_base_exp_fulltext.sh ADDED
@@ -0,0 +1,11 @@
+PROJECT_DIR=${HOME}"/models/eu-jav-categorisation"
+export PYTHONPATH=${PROJECT_DIR}
+INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/mt5_base/checkpoint_1000000\"
+TRAIN_STEPS=1005000
+
+python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_fulltext\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/fulltext_classify_tweets_fulltext_base_v1\" &&
+python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_fulltext\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/fulltext_classify_tweets_fulltext_base_v2\" &&
+python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_fulltext\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/fulltext_classify_tweets_fulltext_base_v3\" &&
+python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_fulltext\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/fulltext_classify_tweets_fulltext_base_v4\" &&
+python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets_fulltext\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/fulltext_classify_tweets_fulltext_base_v5\"
+
finetune_classification_base.gin CHANGED
@@ -11,7 +11,7 @@ include "t5x/examples/t5/mt5/base.gin"
 include "t5x/configs/runs/finetune.gin"
 
 MIXTURE_OR_TASK_NAME = %gin.REQUIRED
-TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 2}
+TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 32}
 INITIAL_CHECKPOINT_PATH = %gin.REQUIRED
 TRAIN_STEPS = %gin.REQUIRED # 1000000 pre-trained steps + 10000 fine-tuning steps.
 USE_CACHED_TASKS = False
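
Why the jump from 2 to 32 target tokens: the fulltext task added in tasks.py further down in this commit emits whole Italian sentences as targets rather than a short label, and those would not fit in a 2-token budget; targets longer than the configured length get trimmed by the feature converter. A hedged sanity check of the new budget is sketched below; it borrows the SentencePiece vocabulary path from tasks_exp_english.py, since the vocabulary used by tasks.py itself is not visible in this diff.

# Hedged sketch (not part of the commit): check that the Italian target
# sentences fit in TASK_FEATURE_LENGTHS["targets"] = 32 tokens.
import seqio

vocab = seqio.SentencePieceVocabulary(
    "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model", extra_ids=0)
for target in ("il testo è favorevole alla vaccinazione",
               "il testo è neutro rispetto alla vaccinazione",
               "il testo è sfavorevole alla vaccinazione"):
  n_tokens = len(vocab.encode(target)) + 1  # +1 for the EOS appended by seqio
  assert n_tokens <= 32, (target, n_tokens)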
finetune_classification_base_exp_english.gin CHANGED
@@ -7,7 +7,7 @@ from t5x import models
 from t5x import partitioning
 from t5x import utils
 
-include "t5x/examples/t5/mt5/base.gin"
+include "t5x/examples/t5/t5_1_1/base.gin"
 include "t5x/configs/runs/finetune.gin"
 
 MIXTURE_OR_TASK_NAME = %gin.REQUIRED
@@ -24,7 +24,7 @@ infer_eval/utils.DatasetConfig:
 
-#Saving every 1000 steps
+# Saving every 500 steps
 utils.SaveCheckpointConfig:
-  period = 100
+  period = 500
 
 
 # Pere: Only necessary if we load a t5 model. We can start with a t5x model here
tasks.py CHANGED
@@ -43,6 +43,36 @@ def categorise_preprocessor(ds):
             tf.strings.join(
                 [normalize_text(ex["target"])]),
     }
+
+  return ds.map(to_inputs_and_targets,
+                num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+
+def categorise_fulltext_preprocessor(ds):
+  def normalize_text(text):
+    """Lowercase and remove quotes from a TensorFlow string."""
+    text = tf.strings.regex_replace(text, "'(.*)'", r"\1")
+    return text
+
+  def fulltext(t):
+    if t == "0":
+      t = "il testo è favorevole alla vaccinazione"
+    elif t == "1":
+      t = "il testo è neutro rispetto alla vaccinazione"
+    elif t == "3":
+      t = "il testo è sfavorevole alla vaccinazione"
+    return t
+
+  def to_inputs_and_targets(ex):
+    """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
+    return {
+        "inputs":
+            tf.strings.join(
+                [normalize_text(ex["source"])]),
+        "targets":
+            tf.strings.join(
+                [fulltext(normalize_text(ex["target"]))]),
+    }
   return ds.map(to_inputs_and_targets,
                 num_parallel_calls=tf.data.experimental.AUTOTUNE)
 
@@ -65,7 +95,7 @@ seqio.TaskRegistry.add(
 )
 
 seqio.TaskRegistry.add(
-    "classify_tweets_a1",
+    "classify_tweets_fulltext",
     source=seqio.TextLineDataSource(
         split_to_filepattern=tsv_path,
         #num_input_examples=num_nq_examples
@@ -73,44 +103,11 @@ seqio.TaskRegistry.add(
     preprocessors=[
         functools.partial(
            t5.data.preprocessors.parse_tsv,
-           field_names=["target","annotator2","annotator3","placeholder","source","id"]),
-        categorise_preprocessor,
-        seqio.preprocessors.tokenize_and_append_eos,
-    ],
-    metric_fns=[metrics.accuracy,my_metrics.f1_macro],
-    output_features=DEFAULT_OUTPUT_FEATURES,
-)
-
-seqio.TaskRegistry.add(
-    "classify_tweets_a2",
-    source=seqio.TextLineDataSource(
-        split_to_filepattern=tsv_path,
-        #num_input_examples=num_nq_examples
-    ),
-    preprocessors=[
-        functools.partial(
-            t5.data.preprocessors.parse_tsv,
-            field_names=["annotator1","target","annotator3","placeholder","source","id"]),
-        categorise_preprocessor,
+           field_names=["annotator1","annotator2","annotator3","target","source","id"]),
+        categorise_fulltext_preprocessor,
         seqio.preprocessors.tokenize_and_append_eos,
     ],
     metric_fns=[metrics.accuracy,my_metrics.f1_macro],
    output_features=DEFAULT_OUTPUT_FEATURES,
 )
 
-seqio.TaskRegistry.add(
-    "classify_tweets_a3",
-    source=seqio.TextLineDataSource(
-        split_to_filepattern=tsv_path,
-        #num_input_examples=num_nq_examples
-    ),
-    preprocessors=[
-        functools.partial(
-            t5.data.preprocessors.parse_tsv,
-            field_names=["annotator1","annotator2","target","placeholder","source","id"]),
-        categorise_preprocessor,
-        seqio.preprocessors.tokenize_and_append_eos,
-    ],
-    metric_fns=[metrics.accuracy,my_metrics.f1_macro],
-    output_features=DEFAULT_OUTPUT_FEATURES,
-)
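
One caveat on the added fulltext() helper: inside ds.map the target field is a tf.string tensor, so the Python if/elif comparisons rely on AutoGraph conversion to branch correctly. A tensor-only variant is sketched below as an illustration; the function name fulltext_tf is hypothetical and not part of this commit, and the label values are copied from the code above.

import tensorflow as tf

def fulltext_tf(t):
  # Illustrative sketch: the same label-to-sentence mapping as fulltext(),
  # expressed with tensor ops (nested tf.where) so it works on scalar
  # tf.string tensors without depending on AutoGraph's if/elif conversion.
  return tf.where(
      tf.equal(t, "0"), "il testo è favorevole alla vaccinazione",
      tf.where(
          tf.equal(t, "1"), "il testo è neutro rispetto alla vaccinazione",
          tf.where(
              tf.equal(t, "3"), "il testo è sfavorevole alla vaccinazione",
              t)))  # fall through: keep the raw label, as in the original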
tasks_exp_english.py CHANGED
@@ -16,7 +16,7 @@ tsv_path = {
 }
 
 vocabulary = seqio.SentencePieceVocabulary(
-    'gs://t5-data/vocabs/cc_en.32000/sentencepiece.model', extra_ids=0)
+    'gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model', extra_ids=0)
 
 DEFAULT_OUTPUT_FEATURES = {
     "inputs":
@@ -63,54 +63,3 @@ seqio.TaskRegistry.add(
     metric_fns=[metrics.accuracy,my_metrics.f1_macro],
     output_features=DEFAULT_OUTPUT_FEATURES,
 )
-
-seqio.TaskRegistry.add(
-    "classify_tweets_a1",
-    source=seqio.TextLineDataSource(
-        split_to_filepattern=tsv_path,
-        #num_input_examples=num_nq_examples
-    ),
-    preprocessors=[
-        functools.partial(
-            t5.data.preprocessors.parse_tsv,
-            field_names=["target","annotator2","annotator3","placeholder","source","id"]),
-        categorise_preprocessor,
-        seqio.preprocessors.tokenize_and_append_eos,
-    ],
-    metric_fns=[metrics.accuracy,my_metrics.f1_macro],
-    output_features=DEFAULT_OUTPUT_FEATURES,
-)
-
-seqio.TaskRegistry.add(
-    "classify_tweets_a2",
-    source=seqio.TextLineDataSource(
-        split_to_filepattern=tsv_path,
-        #num_input_examples=num_nq_examples
-    ),
-    preprocessors=[
-        functools.partial(
-            t5.data.preprocessors.parse_tsv,
-            field_names=["annotator1","target","annotator3","placeholder","source","id"]),
-        categorise_preprocessor,
-        seqio.preprocessors.tokenize_and_append_eos,
-    ],
-    metric_fns=[metrics.accuracy,my_metrics.f1_macro],
-    output_features=DEFAULT_OUTPUT_FEATURES,
-)
-
-seqio.TaskRegistry.add(
-    "classify_tweets_a3",
-    source=seqio.TextLineDataSource(
-        split_to_filepattern=tsv_path,
-        #num_input_examples=num_nq_examples
-    ),
-    preprocessors=[
-        functools.partial(
-            t5.data.preprocessors.parse_tsv,
-            field_names=["annotator1","annotator2","target","placeholder","source","id"]),
-        categorise_preprocessor,
-        seqio.preprocessors.tokenize_and_append_eos,
-    ],
-    metric_fns=[metrics.accuracy,my_metrics.f1_macro],
-    output_features=DEFAULT_OUTPUT_FEATURES,
-)
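
The vocabulary swap above feeds directly into DEFAULT_OUTPUT_FEATURES, whose definition is truncated in this diff. A typical seqio shape for it is sketched below under that assumption; it is not the literal file contents.

import seqio

vocabulary = seqio.SentencePieceVocabulary(
    'gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model', extra_ids=0)

# Assumed layout of the (truncated) DEFAULT_OUTPUT_FEATURES dict: inputs and
# targets tokenized with the same vocabulary, EOS appended to both.
DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(vocabulary=vocabulary, add_eos=True, required=False),
    "targets": seqio.Feature(vocabulary=vocabulary, add_eos=True),
}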