pere committed on
Commit
cb76020
1 Parent(s): 8fe9b93
Files changed (5)
  1. __pycache__/tasks.cpython-38.pyc +0 -0
  2. base.gin +18 -0
  3. pretrain_cont.gin +111 -0
  4. tasks.py +89 -0
  5. train_base.sh +9 -0
__pycache__/tasks.cpython-38.pyc ADDED
Binary file (2.18 kB).
 
base.gin ADDED
@@ -0,0 +1,18 @@
+ include 't5x/examples/t5/mt5/base.gin'
+ include 'pretrain_cont.gin'
+ #include 't5x/configs/runs/pretrain.gin'
+ #include 't5x/configs/runs/finetune.gin'
+
+
+ # Register necessary SeqIO Tasks/Mixtures.
+ import t5.data.mixtures
+ import tasks
+
+ MIXTURE_OR_TASK_NAME = "ncc_scandinavian_span_corruption_stream"
+ TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
+ TRAIN_STEPS = 1_700_000
+ DROPOUT_RATE = 0.0  # Changed from the default since T5-1.1 recommends this.
+ INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_NCC_plus_English_t5x_base/checkpoint_1500000"
+ PjitPartitioner.num_partitions = 1
+
+
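
Note that t5x interprets TRAIN_STEPS as the absolute final step (total_steps), not a count of new steps, so this configuration continues the restored 1,500,000-step checkpoint for another 200,000 steps. A minimal sketch of that arithmetic, assuming the absolute-step convention holds:

    # Assumption: t5x's total_steps is the absolute step to stop at, so
    # continued training runs TRAIN_STEPS - checkpoint_step further steps.
    checkpoint_step = 1_500_000  # from .../checkpoint_1500000
    train_steps = 1_700_000      # TRAIN_STEPS above
    print(f"additional steps: {train_steps - checkpoint_step:,}")  # 200,000
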
pretrain_cont.gin ADDED
@@ -0,0 +1,111 @@
+ # Defaults for pretraining with train.py.
+ #
+ #
+ # You must also include a binding for MODEL.
+ #
+ # Required to be set:
+ #
+ # - MIXTURE_OR_TASK_NAME
+ # - TASK_FEATURE_LENGTHS
+ # - TRAIN_STEPS - includes the steps already completed in INITIAL_CHECKPOINT_PATH
+ # - MODEL_DIR: automatically set when using xm_launch
+ #
+ # Commonly overridden options:
+ #
+ # - train/DatasetConfig.batch_size
+ # - train_eval/DatasetConfig.batch_size
+ # - PjitPartitioner.num_partitions
+ # - Trainer.num_microbatches
+ # - DROPOUT_RATE
+ from __gin__ import dynamic_registration
+
+ import __main__ as train_script
+ from t5x import gin_utils
+ from t5x import partitioning
+ from t5x import utils
+ from t5x import trainer
+
+ MIXTURE_OR_TASK_NAME = %gin.REQUIRED
+ TASK_FEATURE_LENGTHS = %gin.REQUIRED
+ TRAIN_STEPS = %gin.REQUIRED
+ MODEL_DIR = %gin.REQUIRED
+ BATCH_SIZE = 128
+ USE_CACHED_TASKS = True
+ INITIAL_CHECKPOINT_PATH = %gin.REQUIRED
+
+ # DEPRECATED: Import this module in your gin file.
+ MIXTURE_OR_TASK_MODULE = None
+ SHUFFLE_TRAIN_EXAMPLES = True
+
+ # HW RNG is faster than SW, but has limited determinism.
+ # Most notably it is not deterministic across different
+ # submeshes.
+ USE_HARDWARE_RNG = False
+ # None always uses the faster, hardware RNG.
+ RANDOM_SEED = None
+
+ # Can be overridden with `train.*`.
+ train_script.train:
+   model = %MODEL  # imported from separate gin file
+   model_dir = %MODEL_DIR
+   train_dataset_cfg = @train/utils.DatasetConfig()
+   train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
+   infer_eval_dataset_cfg = None
+   checkpoint_cfg = @utils.CheckpointConfig()
+   partitioner = @partitioning.PjitPartitioner()
+   trainer_cls = @trainer.Trainer
+   total_steps = %TRAIN_STEPS
+   eval_steps = 20
+   eval_period = 1000
+   random_seed = %RANDOM_SEED
+   use_hardware_rng = %USE_HARDWARE_RNG
+   summarize_config_fn = @gin_utils.summarize_gin_config
+
+ partitioning.PjitPartitioner:
+   num_partitions = 1
+   model_parallel_submesh = None
+   logical_axis_rules = @partitioning.standard_logical_axis_rules()
+
+ train/utils.DatasetConfig:
+   mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+   task_feature_lengths = %TASK_FEATURE_LENGTHS
+   split = 'train'
+   batch_size = %BATCH_SIZE
+   shuffle = %SHUFFLE_TRAIN_EXAMPLES
+   seed = None  # use a new seed each run/restart
+   use_cached = %USE_CACHED_TASKS
+   pack = True
+   module = %MIXTURE_OR_TASK_MODULE
+
+ train_eval/utils.DatasetConfig:
+   mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+   task_feature_lengths = %TASK_FEATURE_LENGTHS
+   split = 'validation'
+   batch_size = %BATCH_SIZE
+   shuffle = False
+   seed = 42
+   use_cached = %USE_CACHED_TASKS
+   pack = True
+   module = %MIXTURE_OR_TASK_MODULE
+
+ utils.CheckpointConfig:
+   restore = @utils.RestoreCheckpointConfig()
+   save = @utils.SaveCheckpointConfig()
+ utils.RestoreCheckpointConfig:
+   path = %INITIAL_CHECKPOINT_PATH
+   mode = 'specific'
+   dtype = 'float32'
+ utils.SaveCheckpointConfig:
+   period = 1000
+   dtype = 'float32'
+   keep = None  # keep all checkpoints
+   save_dataset = False  # don't checkpoint dataset state
+
+ trainer.Trainer:
+   num_microbatches = None
+   learning_rate_fn = @utils.create_learning_rate_scheduler()
+
+ utils.create_learning_rate_scheduler:
+   factors = 'constant * rsqrt_decay'
+   base_learning_rate = 0.5  # Set to half of the original value since this is continued training.
+   warmup_steps = 10000  # 10k to keep consistent with T5/MTF defaults.
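
For intuition about the schedule configured above: with factors = 'constant * rsqrt_decay' the learning rate stays flat at base_learning_rate / sqrt(warmup_steps) until warmup_steps, then decays as the reciprocal square root of the step. A hedged sketch, assuming the factor semantics of t5x's utils.create_learning_rate_scheduler ('constant' contributes base_learning_rate, 'rsqrt_decay' contributes 1/sqrt(max(step, warmup_steps))):

    import math

    def learning_rate(step, base_learning_rate=0.5, warmup_steps=10_000):
        # 'constant * rsqrt_decay': flat until warmup_steps, then rsqrt decay.
        return base_learning_rate / math.sqrt(max(step, warmup_steps))

    print(f"{learning_rate(10_000):.2e}")     # 5.00e-03 (plateau value)
    print(f"{learning_rate(1_500_000):.2e}")  # ~4.08e-04 at the restored checkpoint step
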
tasks.py ADDED
@@ -0,0 +1,89 @@
+ import functools
+
+ import seqio
+ import tensorflow as tf
+ import t5.data
+ from datasets import load_dataset
+ from t5.data import postprocessors
+ from t5.data import preprocessors
+ from t5.evaluation import metrics
+ from seqio import FunctionDataSource, utils
+
+ TaskRegistry = seqio.TaskRegistry
+
+ vocabulary = seqio.SentencePieceVocabulary('gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model', extra_ids=0)
+ byt5_vocabulary = t5.data.ByteVocabulary()
+
+ DEFAULT_OUTPUT_FEATURES = {
+     "inputs": seqio.Feature(
+         vocabulary=vocabulary, add_eos=True,
+         required=False),
+     "targets": seqio.Feature(
+         vocabulary=vocabulary, add_eos=True)
+ }
+
+ BYT5_DEFAULT_OUTPUT_FEATURES = {
+     "inputs": seqio.Feature(
+         vocabulary=byt5_vocabulary, add_eos=True,
+         required=False),
+     "targets": seqio.Feature(
+         vocabulary=byt5_vocabulary, add_eos=True)
+ }
+
+
+ def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_params=None):
+     """Yield raw text examples from a (streaming) Hugging Face dataset."""
+     dataset = load_dataset(**dataset_params)
+     if shuffle:
+         if seed:
+             dataset = dataset.shuffle(seed=seed)
+         else:
+             dataset = dataset.shuffle()
+     while True:
+         for item in dataset[str(split)]:
+             yield item[column]
+
+
+ def dataset_fn(split, shuffle_files, seed=None, dataset_params=None):
+     """Wrap the generator as a tf.data.Dataset of scalar strings."""
+     return tf.data.Dataset.from_generator(
+         functools.partial(gen_dataset, split, shuffle_files, seed, dataset_params=dataset_params),
+         # `dataset_name` is a module-level variable defined below; it is
+         # resolved lazily when SeqIO builds the dataset.
+         output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name)
+     )
+
+
+ @utils.map_over_dataset
+ def target_to_key(x, key_map, target_key):
+     """Assign the value from the dataset to target_key in key_map."""
+     return {**key_map, target_key: x}
+
+
+ # Final pretraining task used in Raffel et al., 2019, adapted to NCC.
+ dataset_name = 'NbAiLab/scandinavian'
+ dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
+ dataset_shapes = None
+ TaskRegistry.add(
+     "ncc_scandinavian_span_corruption_stream",
+     source=seqio.FunctionDataSource(
+         dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
+         splits=("train", "validation"),
+         caching_permitted=False,
+         num_input_examples=dataset_shapes,
+     ),
+     preprocessors=[
+         functools.partial(
+             target_to_key, key_map={
+                 "inputs": None,
+                 "targets": None,
+             }, target_key="targets"),
+         seqio.preprocessors.tokenize,
+         # seqio.CacheDatasetPlaceholder(),
+         preprocessors.span_corruption,
+         seqio.preprocessors.append_eos_after_trim,
+     ],
+     output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
+     metric_fns=[]
+ )
+
+
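
The registered task can be smoke-tested outside a full t5x run. A minimal sketch, assuming access to the gated NbAiLab/scandinavian dataset; it streams one example through the tokenize/span-corruption pipeline and prints the resulting feature shapes:

    import seqio
    import tasks  # registers "ncc_scandinavian_span_corruption_stream"

    task = seqio.get_mixture_or_task("ncc_scandinavian_span_corruption_stream")
    ds = task.get_dataset(
        sequence_length={"inputs": 512, "targets": 512},
        split="validation",
        shuffle=False,
    )
    for example in ds.take(1):
        print({k: v.shape for k, v in example.items()})
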
train_base.sh ADDED
@@ -0,0 +1,9 @@
+ PROJECT_DIR=${HOME}"/models/long-t5x"
+ T5X_DIR="../../t5x"  # directory where t5x is cloned
+ MODEL_DIR="gs://nb-t5x-us-central2/long_test_t5x_base"
+ export PYTHONPATH=${PROJECT_DIR}
+
+ python3 ${T5X_DIR}/t5x/train.py \
+   --gin_search_paths=${PROJECT_DIR} \
+   --gin_file="base.gin" \
+   --gin.MODEL_DIR="'${MODEL_DIR}'" \