pere commited on
Commit
4a04574
1 Parent(s): 0ca917b

larger models added

Browse files
__pycache__/tasks.cpython-38.pyc CHANGED
Binary files a/__pycache__/tasks.cpython-38.pyc and b/__pycache__/tasks.cpython-38.pyc differ
 
finetuning_categorisation_xl.gin ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __gin__ import dynamic_registration
2
+ import tasks
3
+
4
+ import __main__ as train_script
5
+ from t5.data import mixtures
6
+ from t5x import models
7
+ from t5x import partitioning
8
+ from t5x import utils
9
+
10
+ include "t5x/examples/t5/mt5/xl.gin"
11
+ include "t5x/configs/runs/finetune.gin"
12
+
13
+ MIXTURE_OR_TASK_NAME = "categorise"
14
+ TASK_FEATURE_LENGTHS = {"inputs": 96, "targets": 2}
15
+ TRAIN_STEPS = 1_010_000 # 1000000 pre-trained steps + 10000 fine-tuning steps.
16
+ USE_CACHED_TASKS = False
17
+ DROPOUT_RATE = 0.0
18
+ RANDOM_SEED = 0
19
+ BATCH_SIZE = 8
20
+
21
+
22
+ # Pere: Only necessary if we load a t5 model. We can start with an t5x model here
23
+ # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained
24
+ # using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be
25
+ # set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1:
26
+ # `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`.
27
+ #LOSS_NORMALIZING_FACTOR = 234496
28
+
29
+ INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_xl/checkpoint_1000000"
30
+
31
+ #train_script.train:
32
+ # eval_period = 500
33
+ # partitioner = @partitioning.ModelBasedPjitPartitioner()
34
+ # partitioning.ModelBasedPjitPartitioner.num_partitions = 2
35
+
36
+ # `num_decodes` is equivalent to a beam size in a beam search decoding.
37
+ models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4
38
+
39
+ #mesh_transformer.learning_rate_schedules.constant_learning_rate.learning_rate = 0.0005
40
+ #run.learning_rate_schedule = @learning_rate_schedules.constant_learning_rate
41
+
finetuning_categorisation_xxl.gin ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __gin__ import dynamic_registration
2
+ import tasks
3
+
4
+ import __main__ as train_script
5
+ from t5.data import mixtures
6
+ from t5x import models
7
+ from t5x import partitioning
8
+ from t5x import utils
9
+
10
+ include "t5x/examples/t5/mt5/xxl.gin"
11
+ include "t5x/configs/runs/finetune.gin"
12
+
13
+ MIXTURE_OR_TASK_NAME = "categorise"
14
+ TASK_FEATURE_LENGTHS = {"inputs": 96, "targets": 2}
15
+ TRAIN_STEPS = 1_010_000 # 1000000 pre-trained steps + 10000 fine-tuning steps.
16
+ USE_CACHED_TASKS = False
17
+ DROPOUT_RATE = 0.0
18
+ RANDOM_SEED = 0
19
+
20
+ # Pere: Only necessary if we load a t5 model. We can start with an t5x model here
21
+ # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained
22
+ # using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be
23
+ # set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1:
24
+ # `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`.
25
+ #LOSS_NORMALIZING_FACTOR = 234496
26
+
27
+ INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_xxl/checkpoint_1000000"
28
+
29
+ #train_script.train:
30
+ # eval_period = 500
31
+ # partitioner = @partitioning.ModelBasedPjitPartitioner()
32
+ # partitioning.ModelBasedPjitPartitioner.num_partitions = 2
33
+
34
+ # `num_decodes` is equivalent to a beam size in a beam search decoding.
35
+ models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4
36
+
37
+ #mesh_transformer.learning_rate_schedules.constant_learning_rate.learning_rate = 0.0005
38
+ #run.learning_rate_schedule = @learning_rate_schedules.constant_learning_rate
39
+
tasks.py CHANGED
@@ -9,9 +9,9 @@ import t5
9
  import tensorflow.compat.v1 as tf
10
 
11
  tsv_path = {
12
- "train": "gs://peregilcloud/italian_tweets/train2.tsv",
13
- "validation": "gs://peregilcloud/italian_tweets/dev2.tsv",
14
- "test": "gs://peregilcloud/italian_tweets/test2.tsv"
15
  }
16
 
17
  vocabulary = seqio.SentencePieceVocabulary(
 
9
  import tensorflow.compat.v1 as tf
10
 
11
  tsv_path = {
12
+ "train": "gs://peregilcloud/italian_tweets/train3.tsv",
13
+ "validation": "gs://peregilcloud/italian_tweets/dev.tsv",
14
+ "test": "gs://peregilcloud/italian_tweets/test.tsv"
15
  }
16
 
17
  vocabulary = seqio.SentencePieceVocabulary(
train_large.sh CHANGED
@@ -1,7 +1,7 @@
1
  PROJECT_DIR=${HOME}"/models/eu-jav-categorisation"
2
  T5X_DIR="../../t5x" # directory where the t5x is cloned.
3
  #Needs to be updated when moving to tpu-v4 it should then be in another zone
4
- MODEL_DIR="gs://nb-t5x/eujav_large2"
5
  export PYTHONPATH=${PROJECT_DIR}
6
 
7
  python3 ${T5X_DIR}/t5x/train.py \
 
1
  PROJECT_DIR=${HOME}"/models/eu-jav-categorisation"
2
  T5X_DIR="../../t5x" # directory where the t5x is cloned.
3
  #Needs to be updated when moving to tpu-v4 it should then be in another zone
4
+ MODEL_DIR="gs://nb-t5x/eujav_large3"
5
  export PYTHONPATH=${PROJECT_DIR}
6
 
7
  python3 ${T5X_DIR}/t5x/train.py \
train_xl.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PROJECT_DIR=${HOME}"/models/eu-jav-categorisation"
2
+ T5X_DIR="../../t5x" # directory where the t5x is cloned.
3
+ #Needs to be updated when moving to tpu-v4 it should then be in another zone
4
+ MODEL_DIR="gs://nb-t5x/eujav_xl"
5
+ export PYTHONPATH=${PROJECT_DIR}
6
+
7
+ python3 ${T5X_DIR}/t5x/train.py \
8
+ --gin_search_paths=${PROJECT_DIR} \
9
+ --gin_file="finetuning_categorisation_xl.gin" \
10
+ --gin.MODEL_DIR="'${MODEL_DIR}'"
11
+
train_xxl.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PROJECT_DIR=${HOME}"/models/eu-jav-categorisation"
2
+ T5X_DIR="../../t5x" # directory where the t5x is cloned.
3
+ #Needs to be updated when moving to tpu-v4 it should then be in another zone
4
+ MODEL_DIR="gs://nb-t5x/eujav_xxl"
5
+ export PYTHONPATH=${PROJECT_DIR}
6
+
7
+ python3 ${T5X_DIR}/t5x/train.py \
8
+ --gin_search_paths=${PROJECT_DIR} \
9
+ --gin_file="finetuning_categorisation_xxl.gin" \
10
+ --gin.MODEL_DIR="'${MODEL_DIR}'"
11
+