pere committed
Commit e957cc1 (parent: ad029a3)

updated dataset
__pycache__/my_metrics.cpython-38.pyc ADDED
    Binary file (459 Bytes)

__pycache__/tasks.cpython-38.pyc ADDED
    Binary file (4.79 kB)
finetune_base.sh ADDED
@@ -0,0 +1,11 @@
+PROJECT_DIR=${HOME}"/models/north-t5-base-deuncaser"
+export PYTHONPATH=${PROJECT_DIR}
+INITIAL_CHECKPOINT_PATH=\"gs://north-t5x/pretrained_models/base/norwegian_NCC_plus_English_pluss200k_balanced_bokmaal_nynorsk_t5x_base/checkpoint_1700000\"
+TRAIN_STEPS=1800000
+python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_deuncaser_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"deuncaser\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/deuncaser/deuncaser_base_v1\"
+
+#python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/ts20_classify_tweets_base_v2\" &&
+#python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/ts20_classify_tweets_base_v3\" &&
+#python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/ts20_classify_tweets_base_v4\" &&
+#python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_classification_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"classify_tweets\" --gin.MODEL_DIR=\"gs://eu-jav-t5x/finetuned/italian_tweets/ts_20classify_tweets_base_v5\"
+
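Note: the script exports PYTHONPATH=${PROJECT_DIR}, presumably so that tasks.py (which registers the "deuncaser" task referenced by --gin.MIXTURE_OR_TASK_NAME) is importable when t5x starts. Below is a minimal, hypothetical sketch for checking that the task resolves before launching training; the sequence lengths are made up and the GCS paths must be readable.

    # Sketch only (not part of this commit): confirm that the "deuncaser"
    # task registered in tasks.py resolves and yields examples.
    import seqio
    import tasks  # noqa: F401  (side effect: seqio.TaskRegistry.add("deuncaser", ...))

    task = seqio.TaskRegistry.get("deuncaser")
    ds = task.get_dataset(
        sequence_length={"inputs": 512, "targets": 512},  # hypothetical lengths
        split="validation",
        shuffle=False,
    )
    for ex in ds.take(1).as_numpy_iterator():
        print(ex["inputs"][:20], ex["targets"][:20])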
finetune_deuncaser_base.gin CHANGED
@@ -22,9 +22,9 @@ RANDOM_SEED = 0
 infer_eval/utils.DatasetConfig:
   task_feature_lengths = %TASK_FEATURE_LENGTHS
 
-#Saving every 1000 steps
+#Saving every 10000 steps
 utils.SaveCheckpointConfig:
-  period = 1000
+  period = 10000
 
 
 # Pere: Only necessary if we load a t5 model. We can start with an t5x model here
tasks.py CHANGED
@@ -10,9 +10,9 @@ import t5
 import tensorflow.compat.v1 as tf
 
 tsv_path = {
-    "train": "gs://eu-jav-t5x/corpus/labeled/datasetA_train_3categories.tsv",
-    "validation": "gs://eu-jav-t5x/corpus/labeled/datasetA_dev_3categories.tsv",
-    "test": "gs://eu-jav-t5x/corpus/labeled/ datasetA_test_3categories.tsv"
+    "train": "gs://north-t5x/corpus/deuncaser/norwegian/train.tsv",
+    "validation": "gs://north-t5x/corpus/deuncaser/norwegian/validation.tsv",
+    "test": "gs://north-t5x/corpus/deuncaser/norwegian/validation.tsv"
 }
 
 vocabulary = seqio.SentencePieceVocabulary(
@@ -138,7 +138,7 @@ def categorise_binary_preprocessor(ds):
 
 
 seqio.TaskRegistry.add(
-    "classify_tweets",
+    "deuncaser",
     source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_path,
        #num_input_examples=num_nq_examples
@@ -146,11 +146,11 @@ seqio.TaskRegistry.add(
     preprocessors=[
       functools.partial(
           t5.data.preprocessors.parse_tsv,
-          field_names=["annotator1","annotator2","annotator3","target","source","id"]),
+          field_names=["id","methods","source","target"]),
       categorise_preprocessor,
       seqio.preprocessors.tokenize_and_append_eos,
     ],
-    metric_fns=[metrics.accuracy,my_metrics.f1_macro],
+    metric_fns=[metrics.accuracy,metrics.bleu],
     output_features=DEFAULT_OUTPUT_FEATURES,
 )
 
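Note: the task keeps the existing categorise_preprocessor (its body is not shown in this diff) and only swaps the TSV column order and the metrics. The sketch below illustrates the kind of mapping such a preprocessor would perform for the deuncaser rows, assuming "source" holds the lower-cased/de-punctuated text and "target" the restored sentence; the function name is hypothetical.

    # Sketch only: the real categorise_preprocessor is not part of this diff.
    # Maps the parsed TSV fields (id, methods, source, target) onto the
    # "inputs"/"targets" features that tokenize_and_append_eos expects.
    import tensorflow.compat.v1 as tf

    def deuncaser_preprocessor(ds):
        def to_inputs_and_targets(ex):
            return {
                "inputs": ex["source"],   # assumed: un-cased, de-punctuated input text
                "targets": ex["target"],  # assumed: correctly cased target text
            }
        return ds.map(to_inputs_and_targets,
                      num_parallel_calls=tf.data.experimental.AUTOTUNE)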