aapot committed on
Commit
09eb1bf
1 Parent(s): ed007da

Try to avoid HF Hub git rate limits

config.gin CHANGED
@@ -12,7 +12,7 @@ import tasks
 
 # Macros:
 # ==============================================================================
-BATCH_SIZE = 512
+BATCH_SIZE = 256
 DROPOUT_RATE = 0.0
 LABEL_SMOOTHING = 0.0
 LOSS_NORMALIZING_FACTOR = None
@@ -23,7 +23,7 @@ MODEL_DIR = '/researchdisk/t5x-small-nl24-finnish'
 OPTIMIZER = @adafactor.Adafactor()
 RANDOM_SEED = None
 SHUFFLE_TRAIN_EXAMPLES = True
-TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 114}
+TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 512}
 TRAIN_STEPS = 500000
 USE_CACHED_TASKS = False
 USE_HARDWARE_RNG = False
@@ -123,7 +123,6 @@ network.T5Config.vocab_size = 32128
 train_script.train.checkpoint_cfg = @utils.CheckpointConfig()
 train_script.train.eval_period = 10000
 train_script.train.eval_steps = 20
-train_script.train.hub_model_id = 'Finnish-NLP/t5x-small-nl24-finnish'
 train_script.train.infer_eval_dataset_cfg = None
 train_script.train.model = %MODEL
 train_script.train.model_dir = %MODEL_DIR
config.json CHANGED
@@ -7,7 +7,7 @@
   "d_kv": 64,
   "d_model": 512,
   "decoder_start_token_id": 0,
-  "dropout_rate": 0.0,
+  "dropout_rate": 0.1,
   "eos_token_id": 1,
   "feed_forward_proj": "gated-gelu",
   "initializer_factor": 1.0,
small_nl24_pretrain.gin CHANGED
@@ -11,7 +11,6 @@ include 't5x/configs/runs/pretrain.gin'
 # ------------------- Training specification overrides --------------------------
 train_script.train:
   eval_period = 10000
-  hub_model_id = "Finnish-NLP/t5x-small-nl24-finnish"
 
 utils.SaveCheckpointConfig:
   period = 10000
@@ -19,7 +18,7 @@ utils.SaveCheckpointConfig:
 
 MIXTURE_OR_TASK_NAME = "pretrain_finnish"
 USE_CACHED_TASKS = False
-TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114}
+TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
 TRAIN_STEPS = 500000
 DROPOUT_RATE = 0.0
-BATCH_SIZE = 512
+BATCH_SIZE = 256
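For a rough sense of the trade-off in the new settings (an illustrative calculation, not part of the commit): halving BATCH_SIZE while growing the target length from 114 to 512 keeps the per-batch token budget in the same ballpark.

# Illustrative arithmetic only, using the BATCH_SIZE and TASK_FEATURE_LENGTHS values above.
tokens_before = 512 * (512 + 114)  # 320,512 input+target positions per batch
tokens_after = 256 * (512 + 512)   # 262,144 input+target positions per batch
print(tokens_after / tokens_before)  # ~0.82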
start_train.sh CHANGED
@@ -2,10 +2,11 @@
 unset LD_PRELOAD
 
 PROJECT_DIR="/researchdisk/t5x-small-nl24-finnish"
+T5X_DIR=${HOME}"/t5x"  # directory where the t5x repo is cloned
 MODEL_DIR="/researchdisk/t5x-small-nl24-finnish"
 export PYTHONPATH=${PROJECT_DIR}
 
-python3 train.py \
+python3 ${T5X_DIR}/t5x/train.py \
   --gin_search_paths=${PROJECT_DIR} \
   --gin_file="small_nl24_pretrain.gin" \
   --gin.MODEL_DIR=\"${MODEL_DIR}\"
tasks.py CHANGED
@@ -1,49 +1,82 @@
+# adapted from https://huggingface.co/pere/pk-nb-t5x/blob/main/tasks.py
+
 import functools
+
 import seqio
-from t5.evaluation import metrics
+import tensorflow as tf
+import t5.data
+from datasets import load_dataset, load_from_disk
+from t5.data import postprocessors
 from t5.data import preprocessors
+from t5.evaluation import metrics
+from seqio import FunctionDataSource, utils
 
-vocabulary = seqio.SentencePieceVocabulary('spiece.model')
-output_features = {
-    'inputs': seqio.Feature(vocabulary=vocabulary, add_eos=True, required=False),
-    'targets': seqio.Feature(vocabulary=vocabulary, add_eos=True)
+TaskRegistry = seqio.TaskRegistry
+
+vocabulary = seqio.SentencePieceVocabulary('spiece.model', extra_ids=0)
+
+DEFAULT_OUTPUT_FEATURES = {
+    "inputs": seqio.Feature(
+        vocabulary=vocabulary, add_eos=True,
+        required=False),
+    "targets": seqio.Feature(
+        vocabulary=vocabulary, add_eos=True)
 }
 
-seqio.TaskRegistry.add(
-    'pretrain_finnish',
-    source=seqio.TextLineDataSource({
-        "train": "/researchdisk/lm_training_dataset_full_sentences/train.txt",
-        "validation": "/researchdisk/lm_training_dataset_full_sentences/validation.txt"
-    }),
+
+def gen_dataset(split, shuffle=False, seed=None, column="text", dataset=None):
+    if shuffle:
+        if seed:
+            dataset = dataset.shuffle(seed=seed)
+        else:
+            dataset = dataset.shuffle()
+    while True:
+        for item in dataset[str(split)]:
+            yield item[column]
+
+
+def dataset_fn(split, shuffle_files, seed=None, dataset=None):
+    return tf.data.Dataset.from_generator(
+        functools.partial(gen_dataset, split, shuffle_files, seed, dataset=dataset),
+        output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name)
+    )
+
+
+@utils.map_over_dataset
+def target_to_key(x, key_map, target_key):
+    """Assign the value from the dataset to target_key in key_map"""
+    return {**key_map, target_key: x}
+
+
+# Final pretraining task used in Raffel et al., 2019 adapted to NCC
+dataset_name = "/researchdisk/lm_training_dataset_full"
+dataset_params = {"from_disk_path": dataset_name}
+
+if "from_disk_path" in dataset_params:
+    dataset = load_from_disk(dataset_params.get("from_disk_path"))
+else:
+    dataset = load_dataset(**dataset_params)
+
+dataset_shapes = {"train": dataset["train"].num_rows, "validation": dataset["validation"].num_rows}
+TaskRegistry.add(
+    "pretrain_finnish",
+    source=seqio.FunctionDataSource(
+        dataset_fn=functools.partial(dataset_fn, dataset=dataset),
+        splits=("train", "validation"),
+        caching_permitted=False,
+        num_input_examples=dataset_shapes,
+    ),
     preprocessors=[
         functools.partial(
-            preprocessors.parse_tsv,
-            field_names=["text"],
-            field_delim="\n"),
-        functools.partial(
-            preprocessors.rekey, key_map={
+            target_to_key, key_map={
                 "inputs": None,
-                "targets": "text"
-            }),
+                "targets": None,
+            }, target_key="targets"),
         seqio.preprocessors.tokenize,
-        seqio.CacheDatasetPlaceholder(),
-        preprocessors.span_corruption,
+        # seqio.CacheDatasetPlaceholder(),
+        preprocessors.span_corruption,
         seqio.preprocessors.append_eos_after_trim,
     ],
-    metric_fns=[metrics.accuracy],
-    output_features=output_features)
-
-# dataset = seqio.get_mixture_or_task("pretrain_finnish").get_dataset(
-#     sequence_length={"inputs": 512, "targets": 114},
-#     split="train",
-#     shuffle=True,
-#     num_epochs=1,
-#     #shard_info=seqio.ShardInfo(index=0, num_shards=10),
-#     use_cached=False,
-#     seed=42
-# )
-
-
-# # Print the first 5 examples.
-# for _, ex in zip(range(5), dataset.as_numpy_iterator()):
-#     print(ex)
+    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
+    metric_fns=[metrics.accuracy]
+)
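The old tasks.py ended with a commented-out preview of the registered task; the same kind of sanity check still applies to the new HF-datasets-backed source. A minimal sketch (not part of the commit), assuming tasks.py is importable from the working directory and the dataset path above exists:

# Minimal sketch of a local preview of the registered task, mirroring the
# commented-out snippet the old tasks.py carried.
import seqio
import tasks  # noqa: F401 -- importing registers "pretrain_finnish"

ds = seqio.get_mixture_or_task("pretrain_finnish").get_dataset(
    sequence_length={"inputs": 512, "targets": 512},  # matches the new TASK_FEATURE_LENGTHS
    split="train",
    shuffle=False,
    num_epochs=1,
    use_cached=False,
    seed=42,
)

# Print the first 2 tokenized examples.
for _, ex in zip(range(2), ds.as_numpy_iterator()):
    print(ex)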
train.py DELETED
@@ -1,689 +0,0 @@
-# Copyright 2022 The T5X Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-r"""Script to pretrain or finetune in JAX using a SeqIO pipeline.
-
-"""
-import functools
-import itertools
-import math
-import os
-import time
-from typing import Callable, Iterator, Sequence, Mapping, Tuple, Type, Optional
-import subprocess
-
-# Set Linen to add profiling information when constructing Modules.
-# Must be set before flax imports.
-# pylint:disable=g-import-not-at-top
-os.environ['FLAX_PROFILE'] = 'true'
-# TODO(adarob): Re-enable once users are notified and tests are updated.
-os.environ['FLAX_LAZY_RNG'] = 'no'
-from absl import logging
-from clu import metric_writers
-import jax
-from jax import random
-from jax.experimental import multihost_utils
-import jax.numpy as jnp
-import numpy as np
-import seqio
-from t5x import models
-from t5x import partitioning
-from t5x import train_state as train_state_lib
-from t5x import trainer as trainer_lib
-from t5x import utils
-from t5x import checkpoint_importer
-LazyArray = checkpoint_importer.LazyArray
-import tensorflow as tf
-
-
-# Automatically search for gin files relative to the T5X package.
-_DEFAULT_GIN_SEARCH_PATHS = [
-    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-]
-PyTreeDef = type(jax.tree_structure(None))
-P = partitioning.PartitionSpec
-# Special key that used to distinguish train metrics.
-TRAIN_METRIC_KEY = 'train'
-# String keys that is acceptable from config.
-_ACTION_KEYS = frozenset(trainer_lib.ActionMode.__members__.keys())
-
-
-def run_actions(
-    mode: trainer_lib.ActionMode, actions: trainer_lib.ActionMapType,
-    train_state: train_state_lib.TrainState,
-    metrics_by_task: Mapping[str, trainer_lib.MetricValueMapType]) -> bool:
-  """Invokes all actions on the given mode on host 0, then broadcasts to all.
-
-  Args:
-    mode: The mode to run the actions. e.g., if mode is `train`, only actions
-      configured to run with `train` mode will be invoked.
-    actions: A mapping of actions that runs after train, eval or infer_eval, to
-      inspect the model and perform useful operations, e.g., early stopping.
-    train_state: The current train_state of the trainer.
-    metrics_by_task: A map of metrics keyed by task name.
-
-  Returns:
-    A bool indicating whether training should be halted.
-
-  Raises:
-    RuntimeError: When the metrics processed on host 0 is None.
-  """
-  stop_training = False
-  if jax.process_index() == 0:
-    if not metrics_by_task:
-      raise RuntimeError('Metric is unexpectedly empty on process 0')
-    for action in actions.get(mode, []):
-      stop_training |= action.run(train_state, metrics_by_task=metrics_by_task)
-  # Broadcast result from host 0 to others.
-  return bool(multihost_utils.broadcast_one_to_all(jnp.array(stop_training)))
-
-
-def train(
-    *,
-    model: models.BaseTransformerModel,
-    train_dataset_cfg: utils.DatasetConfig,
-    train_eval_dataset_cfg: Optional[utils.DatasetConfig],
-    infer_eval_dataset_cfg: Optional[utils.DatasetConfig],
-    checkpoint_cfg: utils.CheckpointConfig,
-    partitioner: partitioning.BasePartitioner,
-    trainer_cls: Type[trainer_lib.BaseTrainer],
-    model_dir: str,
-    total_steps: int,
-    eval_steps: int,
-    eval_period: int,
-    stats_period: Optional[int] = None,
-    random_seed: Optional[int],
-    use_hardware_rng: bool = False,
-    summarize_config_fn: Callable[[str, metric_writers.MetricWriter, int],
-                                  None],
-    inference_evaluator_cls: Type[seqio.Evaluator] = seqio.Evaluator,
-    get_dataset_fn: utils.GetDatasetCallable = utils.get_dataset,
-    concurrent_metrics: bool = True,
-    actions: Optional[Mapping[str, Sequence[trainer_lib.BaseAction]]] = None,
-    train_eval_get_dataset_fn: Optional[utils.GetDatasetCallable] = None,
-    run_eval_before_training: bool = False,
-    hub_model_id: str = None,
-) -> Tuple[int, train_state_lib.TrainState]:
-  """Train function.
-
-  Args:
-    model: The model object to use for training.
-    train_dataset_cfg: Specification for the dataset to train with.
-    train_eval_dataset_cfg: Specification for the dataset to evaluate with using
-      the train metrics and no inference (e.g., uses teacher forcing). If None,
-      train eval is disabled.
-    infer_eval_dataset_cfg: Specification for the dataset to evaluate with using
-      the inference metrics (e.g., uses sampled decoding). If None, inference
-      eval is disabled.
-    checkpoint_cfg: Specification for saving and restoring model parameters and
-      dataset state to/from checkpoints.
-    partitioner: Partitioner for model parameters and data across devices.
-    trainer_cls: An implementation of BaseTrainer.
-    model_dir: Path of directory to store checkpoints and metric summaries.
-    total_steps: The step number to stop training after. The number of actual
-      steps trained in this run will be this number minus the starting step from
-      the checkpoint.
-    eval_steps: The number of batches to process for each train-eval loop.
-    eval_period: The number of train steps between each evaluation (both
-      train-eval and infer-eval).
-    stats_period: The number of train steps between writing scalar stats. If
-      None, defaults to eval_period.
-    random_seed: A random seed to use for dropout and initialization. If None, a
-      fast, non-deterministic hardware-based RNG is used.
-    use_hardware_rng: Whether to force using the RngBitGenerator based hardware
-      rng, which takes seeds and acts similarly to software PRNG in that it
-      should be seed-deterministic. The new RngBitGenerator custom PRNG system
-      should be reproducible for a given sharding, but the numbers will change
-      for different shardings of the same model.
-    summarize_config_fn: A function that takes in the model directory, a
-      SummaryWriter, and the step number, and writes a summary of the
-    inference_evaluator_cls: seqio.Evaluator class to use for inference
-      evaluation, potentially with bound configuration args.
-    get_dataset_fn: The callable use to get the train and train-eval datasets
-      based on the DatasetConfig and shard information.
-    concurrent_metrics: If True, allow metrics computation and logging to
-      overlap with training. Will likely result in additional TPU memory usage.
-    actions: A mapping of actions that runs after train, eval or infer_eval, to
-      inspect the model and perform useful operations, e.g., early stopping. The
-      key must have a 1:1 mapping to ActionMode enum. For EVAL actions to
-      actually work, this requires `concurrent_metrics` to be turned off,
-      since chaining futures and mutating states concurrently might be
-      error-prone.
-    train_eval_get_dataset_fn: Optional callable use to get the train-eval
-      datasets based on the DatasetConfig and shard information. If missing, it
-      defaults to `get_dataset_fn`.
-    run_eval_before_training: If True, calculate training eval and inference
-      eval metrics before training begins.
-
-  Returns:
-    The tuple of (last_step, last_train_state).
-  """
-  logging.info('Process ID: %d', jax.process_index())
-  tf.io.gfile.makedirs(model_dir)
-
-  # Each "epoch" of the training loop should be the min of the eval period,
-  # checkpoint period or the full training.
-  # We compute here to ensure that the eval period and checkpoint period are
-  # divisible by this number, otherwise we fail.
-  eval_enabled = (train_eval_dataset_cfg or infer_eval_dataset_cfg)
-  eval_period = eval_period if eval_enabled else 0
-  checkpoint_period = checkpoint_cfg.save.period if checkpoint_cfg.save else 0
-  if eval_period or checkpoint_period:
-    steps_per_epoch = min(eval_period or np.inf, checkpoint_period or np.inf)
-  else:
-    steps_per_epoch = total_steps
-  stats_period = stats_period or steps_per_epoch
-  if (eval_period and eval_period % steps_per_epoch or
-      checkpoint_period and checkpoint_period % steps_per_epoch):
-    raise ValueError(
-        f'Checkpoint period ({checkpoint_period}) must evenly divide eval '
-        f'period ({eval_period}), or vice-versa.')
-
-  if use_hardware_rng or random_seed is None:
-    logging.info(
-        'Using fast RngBitGenerator PRNG for initialization and dropout.')
-
-    if random_seed is None:
-      random_seed = multihost_utils.broadcast_one_to_all(np.int32(time.time()))
-      logging.info('Random seed not provided, using RNG seed %s', random_seed)
-    else:
-      logging.warning(
-          'When using hardware RNG with a fixed seed, repeatability is only '
-          'guaranteed for fixed hardware and partitioning schemes and for a '
-          'fixed version of this code and its dependencies.')
-    utils.set_hardware_rng_ops()
-    rng = random.PRNGKey(random_seed)
-  else:
-    logging.info('Using seed for initialization and dropout RNG: %d',
-                 random_seed)
-    rng = random.PRNGKey(random_seed)
-
-  init_rng, trainer_rng = random.split(rng, 2)
-
-  # ---------------------------------------------------------------------------
-  # Initialize datasets
-  # ---------------------------------------------------------------------------
-
-  if (train_dataset_cfg.seed and
-      not (checkpoint_cfg.save or checkpoint_cfg.save.save_dataset)):
-    logging.warning(
-        'Providing a random seed for the train dataset with '
-        '`checkpoint_train_ds=False` is dangerous since each '
-        'preemption/restart will cause the dataset to deterministically replay '
-        'from the beginning.')
-
-  data_layout = partitioner.get_data_layout(train_dataset_cfg.batch_size)
-  ds_shard_id = data_layout.shard_id
-  num_ds_shards = data_layout.num_shards
-
-  def _verify_matching_vocabs(cfg: utils.DatasetConfig):
-    ds_vocabs = utils.get_vocabulary(cfg)
-    if (ds_vocabs[0] != model.input_vocabulary or
-        ds_vocabs[1] != model.output_vocabulary):
-      raise ValueError(f'Model and Task vocabularies do not match:\n'
-                       f'  task={cfg.mixture_or_task_name}\n'
-                       f'  ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n'
-                       f'  model.input_vocabulary={model.input_vocabulary}\n'
-                       f'  model.output_vocabulary={model.output_vocabulary}\n')
-
-  _verify_matching_vocabs(train_dataset_cfg)
-
-  train_ds = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards,
-                            model.FEATURE_CONVERTER_CLS)
-
-  if train_eval_dataset_cfg:
-    _verify_matching_vocabs(train_eval_dataset_cfg)
-    train_eval_datasets = utils.get_training_eval_datasets(
-        train_eval_dataset_cfg,
-        ds_shard_id,
-        num_ds_shards,
-        eval_steps,
-        model.FEATURE_CONVERTER_CLS,
-        get_dataset_fn=train_eval_get_dataset_fn if train_eval_get_dataset_fn
-        is not None else get_dataset_fn)  # type: Mapping[str, tf.data.Dataset]
-    if not train_eval_datasets:
-      logging.warning(
-          'No train_eval datasets loaded from config `train_eval_dataset_cfg`: '
-          '%s', train_eval_dataset_cfg)
-  else:
-    train_eval_datasets = {}
-
-  # Initialize optimizer, maybe from an existing checkpoint.
-  checkpointable_train_iter: tf.data.Iterator = iter(train_ds)  # pytype:disable=annotation-type-mismatch
-  train_iter: Iterator[trainer_lib.BatchType] = map(
-      lambda x: jax.tree_map(np.array, x), checkpointable_train_iter)
-
-  # The manner in which parameters are initialized follows this order of
-  # preference:
-  #  1. From a T5X checkpoint in `model_dir`, if one exists.
-  #  2. From a T5X or TF checkpoint specified by `cfg.path`, if set.
-  #  3. From scratch using `init_fn`.
-
-  # 1. From a T5X checkpoint in `model_dir`, if one exists.
-  if checkpoint_cfg.restore is not None:
-    state_transforms_for_restore = [
-        functools.partial(fn, is_resuming=True)
-        for fn in checkpoint_cfg.restore.state_transformation_fns
-    ]
-  else:
-    state_transforms_for_restore = []
-  restore_cfgs = [
-      utils.RestoreCheckpointConfig(
-          path=model_dir,
-          mode='latest',
-          dtype=checkpoint_cfg.save.dtype,
-          checkpointer_cls=checkpoint_cfg.save.checkpointer_cls,
-          # Restore dataset state if it is being saved.
-          restore_dataset=(checkpoint_cfg.save and
-                           checkpoint_cfg.save.save_dataset),
-          state_transformation_fns=state_transforms_for_restore)
-  ]
-  # 2. From a checkpoint specified by `checkpoint_cfg.restore.path`, if set.
-  if checkpoint_cfg.restore:
-    if checkpoint_cfg.restore.mode == 'all':
-      raise ValueError(
-          "Restore checkpoint mode 'all' is not supported in training.")
-
-    # TODO(dhgarrette): Split "restore" behavior into separate configurations
-    #     for the initial restoration for a new run, vs resuming a stopped run.
-    if isinstance(checkpoint_cfg.restore.path, str):
-      restore_cfgs.append(checkpoint_cfg.restore)
-    elif not checkpoint_cfg.restore.path:
-      # `path` is an empty (non-`str`) sequence, so there is nothing to restore.
-      pass
-    else:
-      raise ValueError(
-          'Restore checkpoint config may only have a single path in training.')
-
-  # Need to use full batch size.
-  input_shapes = {
-      k: (data_layout.batch_size, *v.shape[1:])
-      for k, v in train_ds.element_spec.items()
-  }
-  input_types = {
-      k: v.dtype.as_numpy_dtype() for k, v in train_ds.element_spec.items()
-  }
-  init_or_restore_tick = time.time()
-  train_state_initializer = utils.TrainStateInitializer(
-      optimizer_def=model.optimizer_def,
-      init_fn=model.get_initial_variables,
-      input_shapes=input_shapes,
-      input_types=input_types,
-      partitioner=partitioner)
-  # 3. From scratch using `init_fn`.
-  train_state = train_state_initializer.from_checkpoint_or_scratch(
-      restore_cfgs, init_rng=init_rng, ds_iter=checkpointable_train_iter)
-  train_state_axes = train_state_initializer.train_state_axes
-  init_or_restore_secs = time.time() - init_or_restore_tick
-  logging.info('Initialize/restore complete (%.2f seconds).',
-               init_or_restore_secs)
-
-  # Log the variable shapes information and write to a file.
-  log_file = os.path.join(model_dir, 'model-info.txt')
-  utils.log_model_info(log_file,
-                       train_state_initializer.global_train_state_shape,
-                       partitioner)
-
-  if checkpoint_period:
-    checkpointer = checkpoint_cfg.save.checkpointer_cls(
-        train_state=train_state_initializer.global_train_state_shape,
-        partitioner=partitioner,
-        checkpoints_dir=model_dir,
-        dataset_iterator=(checkpointable_train_iter
-                          if checkpoint_cfg.save.save_dataset else None),
-        save_dtype=checkpoint_cfg.save.dtype,
-        keep=checkpoint_cfg.save.keep)
-
-
-  # Restore step from last checkpoint or set to 0 if training from scratch.
-  host_step = int(train_state.step)
-
-  # ---------------------------------------------------------------------------
-  # Trainer
-  # ---------------------------------------------------------------------------
-
-  trainer: trainer_lib.BaseTrainer = trainer_cls(
-      model=model,
-      train_state=train_state,
-      partitioner=partitioner,
-      train_state_axes=train_state_axes,
-      eval_names=train_eval_datasets.keys(),
-      summary_dir=model_dir,
-      rng=trainer_rng)
-  del train_state
-
-  train_metrics = trainer.train_metrics_manager
-  summarize_config_fn(model_dir, train_metrics.summary_writer, host_step)
-
-  train_metrics.write_scalar('timing/init_or_restore_seconds',
-                             init_or_restore_secs, host_step)
-
-  # ----------------------------------------------------------------------------
-  # SeqIO (inference-based) evaluation setup
-  # ----------------------------------------------------------------------------
-  # Init evaluator to set up cached datasets
-  evaluator = None
-  if infer_eval_dataset_cfg is not None:
-    _verify_matching_vocabs(infer_eval_dataset_cfg)
-    evaluator = inference_evaluator_cls(
-        log_dir=os.path.join(model_dir, 'inference_eval'),
-        mixture_or_task_name=infer_eval_dataset_cfg.mixture_or_task_name,
-        feature_converter=model.FEATURE_CONVERTER_CLS(pack=False),
-        eval_split=infer_eval_dataset_cfg.split,
-        use_cached=infer_eval_dataset_cfg.use_cached,
-        seed=infer_eval_dataset_cfg.seed,
-        sequence_length=infer_eval_dataset_cfg.task_feature_lengths,
-        use_memory_cache=infer_eval_dataset_cfg.use_memory_cache)
-    if not evaluator.eval_tasks:
-      # Skip evaluaton.
-      evaluator = None
-
-  if evaluator is not None:
-    predict_fn = utils.get_infer_fn(
-        infer_step=model.predict_batch,
-        batch_size=infer_eval_dataset_cfg.batch_size,
-        train_state_axes=train_state_axes,
-        partitioner=partitioner)
-
-    score_fn = utils.get_infer_fn(
-        infer_step=model.score_batch,
-        batch_size=infer_eval_dataset_cfg.batch_size,
-        train_state_axes=train_state_axes,
-        partitioner=partitioner)
-
-  if actions is None:
-    actions = {}
-
-  if set(actions.keys()).difference(_ACTION_KEYS):
-    raise ValueError(f'actions keys must be one of {_ACTION_KEYS}, but got : '
-                     f'{actions.keys()}')
-
-  # Transform the string key into proper ActionMode enum.
-  actions = {trainer_lib.ActionMode[k]: v for k, v in actions.items()}
-
-  if concurrent_metrics and actions.get(trainer_lib.ActionMode.INFER_EVAL,
-                                        None) is not None:
-    logging.warning('Actions for INFER_EVAL will not be triggered when async '
-                    'metrics computation is enabled')
-  if concurrent_metrics and actions.get(trainer_lib.ActionMode.TRAIN,
-                                        None) is not None:
-    logging.warning('Actions for TRAIN will not be triggered when async '
-                    'metrics computation is enabled')
-
-  # ----------------------------------------------------------------------------
-  # Setup Eval Utility Functions
-  # ----------------------------------------------------------------------------
-  def _run_training_eval(first_run: bool = False):
-    if first_run:
-      logging.info('Compiling training eval loop.')
-      trainer.compile_eval({
-          task: utils.get_zeros_batch_like_dataset(ds)
-          for task, ds in train_eval_datasets.items()
-      })
-    logging.info('Computing training evaluation metrics.')
-    eval_batch_iters = {
-        task: ds.as_numpy_iterator()
-        for task, ds in train_eval_datasets.items()
-    }
-    eval_summaries = trainer.eval(eval_batch_iters)
-    trainer.stop_training = run_actions(trainer_lib.ActionMode.TRAIN_EVAL,
-                                        actions, trainer.train_state,
-                                        eval_summaries)
-
-  def _run_inference_eval():
-    """Run prediction based inference eval."""
-    if evaluator is None:
-      return
-    logging.info('Running inference evaluation.')
-    evaluate_tick = time.time()
-    all_metrics, _, _ = evaluator.evaluate(
-        compute_metrics=jax.process_index() == 0,
-        step=host_step,
-        predict_fn=functools.partial(
-            predict_fn,
-            train_state=trainer.train_state,
-            rng=jax.random.PRNGKey(0)),
-        score_fn=functools.partial(score_fn, train_state=trainer.train_state))
-    if not concurrent_metrics:
-      # Ensure metrics are finished being computed.
-      all_metrics_done = all_metrics.result() or {}
-      trainer.stop_training = run_actions(trainer_lib.ActionMode.INFER_EVAL,
-                                          actions, trainer.train_state,
-                                          all_metrics_done)
-    train_metrics.write_scalar('timing/evaluate_seconds',
-                               time.time() - evaluate_tick, host_step)
-
-  # Optionally run teacher-forcing training eval and SeqIO inference-base eval
-  # before training. Useful for testing how much a model knows before any
-  # finetuning.
-  if run_eval_before_training:
-    if train_eval_datasets:
-      logging.info('Running training eval before training.')
-      _run_training_eval(first_run=True)
-    if evaluator is not None:
-      logging.info('Running inference eval before training.')
-      _run_inference_eval()
-
-  # ----------------------------------------------------------------------------
-  # Main training loop
-  # ----------------------------------------------------------------------------
-  logging.info('Starting training loop.')
-
-  first_step = host_step
-
-  if total_steps < first_step:
-    raise ValueError(
-        f'Unexpected total_steps ({total_steps}) < checkpoint step '
-        f' ({first_step}).')
-
-  logging.info('Starting main loop over steps %d-%d', first_step, total_steps)
-
-  steps_per_epoch = min(steps_per_epoch, total_steps)
-  first_epoch = first_step // steps_per_epoch
-  num_epochs = first_epoch + math.ceil(
-      (total_steps - first_step) / steps_per_epoch)
-  logging.info('Training with artificial "epochs" of %d steps.',
-               steps_per_epoch)
-
-  # Kickstart training dataset and compile train loop.
-  logging.info('Kickstarting train dataset prefetch.')
-  logging.flush()
-
-  ds_tick = time.time()
-  # Get first batch to warm up the dataset pipeline.
-  first_batch = next(train_iter)
-  # Prepend first batch back to iterator to be used by trainer.
-  train_iter = itertools.chain([first_batch], train_iter)
-  train_metrics.write_scalar('timing/dataset_warmup_seconds',
-                             time.time() - ds_tick, host_step)
-  logging.info('Compiling train loop.')
-  logging.flush()
-  trainer.compile_train(first_batch)
-
-  # Main Loop over "epochs".
-  for epoch in range(first_epoch, num_epochs):
-    final_epoch = epoch == num_epochs - 1
-    logging.info('Epoch %d of %d', epoch, num_epochs)
-
-    # `stop_training` is requested, break out the main loop immediately.
-    if trainer.stop_training:
-      break
-
-    logging.info('BEGIN Train loop.')
-    try:
-      # Until the last epoch, `num_steps = steps_per_epoch`
-      num_steps = min(total_steps - host_step, steps_per_epoch)
-      epoch_end_step = host_step + num_steps
-      logging.info('Training for %d steps.', num_steps)
-      while host_step < epoch_end_step:
-        if trainer.stop_training:
-          logging.info('Saving a checkpoint before early stopping...')
-          checkpointer.save(trainer.train_state,
-                            checkpoint_cfg.save.state_transformation_fns)
-
-          if hub_model_id:
-            # convert checkpoint to HF Flax model and push to hub
-            checkpoint_step = trainer.train_state.step
-            checkpoint_step = checkpoint_step.get() if isinstance(checkpoint_step, LazyArray) else checkpoint_step
-            checkpoint_step = int(checkpoint_step)  # Integer, to avoid side effects in the checkpoint path.
-            config_path = os.path.join(model_dir, 'config.json')
-            subprocess.run(["python3", "convert_t5x_checkpoint_to_flax.py", f"--t5x_checkpoint_path='checkpoint_{checkpoint_step}'/'", f'--config_name="{config_path}"', "--flax_dump_folder_path='./'"])
-            subprocess.run("git lfs prune --verify-remote", shell=True)
-            subprocess.run("git add .", shell=True)
-            subprocess.run(f'git commit -m "Saving weights and logs of step {checkpoint_step}"', shell=True)
-            subprocess.Popen("git push", shell=True)
-
-          logging.info('Stopping training loop early since `stop_training` is '
-                       'requested.')
-          break
-
-        inner_num_steps = min(epoch_end_step - host_step, stats_period)
-        train_summary = trainer.train(
-            train_iter, inner_num_steps, start_step=host_step)
-        if not concurrent_metrics:
-          # Note that we always pass the dictionary of `tasks` -> summary so
-          # that the actions can be performed without special casing. The only
-          # caveat is that train would need its own special `key` given no
-          # `task` will be applied.
-          trainer.stop_training = run_actions(
-              trainer_lib.ActionMode.TRAIN, actions, trainer.train_state,
-              {TRAIN_METRIC_KEY: train_summary.result()})
-
-        host_step += inner_num_steps
-      logging.info('END Train loop.')
-    except trainer_lib.PreemptionError as e:
-      logging.info('Saving emergency checkpoint.')
-      checkpointer.save(trainer.train_state,
-                        checkpoint_cfg.save.state_transformation_fns)
-      logging.info('Saving emergency checkpoint done.')
-      raise e
-
-    step_offset = host_step - first_step
-
-    is_eval_epoch = eval_period and (final_epoch or
-                                     step_offset % eval_period == 0)
-
-    # Training Evaluation (i.e., with teacher forcing).
-    if is_eval_epoch and train_eval_datasets:
-      # Maybe less if final step < period.
-      first_run = step_offset // eval_period <= 1
-      _run_training_eval(first_run and not run_eval_before_training)
-
-    # Maybe save a checkpoint.
-    if checkpoint_period and (final_epoch or
-                              step_offset % checkpoint_period == 0):
-      # Make sure last train step has completed before starting the clock.
-      train_summary.result()
-      logging.info('Saving checkpoint.')
-      checkpoint_tick = time.time()
-      checkpointer.save(trainer.train_state,
-                        checkpoint_cfg.save.state_transformation_fns)
-      checkpoint_tock = time.time()
-      train_metrics.write_scalar('timing/checkpoint_seconds',
-                                 checkpoint_tock - checkpoint_tick, host_step)
-
-      if hub_model_id:
-        # convert checkpoint to HF Flax model and push to hub
-        checkpoint_step = trainer.train_state.step
-        checkpoint_step = checkpoint_step.get() if isinstance(checkpoint_step, LazyArray) else checkpoint_step
-        checkpoint_step = int(checkpoint_step)  # Integer, to avoid side effects in the checkpoint path.
-        config_path = os.path.join(model_dir, 'config.json')
-        subprocess.run(["python3", "convert_t5x_checkpoint_to_flax.py", f"--t5x_checkpoint_path='checkpoint_{checkpoint_step}'/'", f'--config_name="{config_path}"', "--flax_dump_folder_path='./'"])
-        subprocess.run("git lfs prune --verify-remote", shell=True)
-        subprocess.run("git add .", shell=True)
-        subprocess.run(f'git commit -m "Saving weights and logs of step {checkpoint_step}"', shell=True)
-        subprocess.Popen("git push", shell=True)
-
-    # Inference Evaluation (i.e., with decoding or scoring).
-    if evaluator is not None:
-      _run_inference_eval()
-
-  # Wait until computations are done before exiting
-  logging.info('Finished.')
-  trainer.close()
-  if evaluator:
-    evaluator.close()
-  multihost_utils.sync_global_devices('complete')
-
-  return host_step, trainer.train_state
-
-
-if __name__ == '__main__':
-  # pylint: disable=g-import-not-at-top
-  from absl import app
-  from absl import flags
-  import gin
-  from t5x import gin_utils
-  # pylint: enable=g-import-not-at-top
-
-  FLAGS = flags.FLAGS
-
-  jax.config.parse_flags_with_absl()
-
-  flags.DEFINE_multi_string(
-      'gin_file',
-      default=None,
-      help='Path to gin configuration file. Multiple paths may be passed and '
-      'will be imported in the given order, with later configurations '
-      'overriding earlier ones.')
-
-  flags.DEFINE_multi_string(
-      'gin_bindings', default=[], help='Individual gin bindings.')
-
-  flags.DEFINE_list(
-      'gin_search_paths',
-      default=['.'],
-      help='Comma-separated list of gin config path prefixes to be prepended '
-      'to suffixes given via `--gin_file`. If a file appears in. Only the '
-      'first prefix that produces a valid path for each suffix will be '
-      'used.')
-
-  flags.DEFINE_string(
-      'tfds_data_dir', None,
-      'If set, this directory will be used to store datasets prepared by '
-      'TensorFlow Datasets that are not available in the public TFDS GCS '
-      'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
-      'all `Task`s.')
-
-  flags.DEFINE_list(
-      'seqio_additional_cache_dirs', [],
-      'Directories to search for cached Tasks in addition to defaults.')
-
-
-
-  def main(argv: Sequence[str]):
-    """Wrapper for pdb post mortems."""
-    _main(argv)
-
-  def _main(argv: Sequence[str]):
-    """True main function."""
-    if len(argv) > 1:
-      raise app.UsageError('Too many command-line arguments.')
-
-    if FLAGS.tfds_data_dir:
-      seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir)
-
-    seqio.add_global_cache_dirs(FLAGS.seqio_additional_cache_dirs)
-
-    # Create gin-configurable version of `train`.
-    train_using_gin = gin.configurable(train)
-
-    gin_utils.parse_gin_flags(
-        # User-provided gin paths take precedence if relative paths conflict.
-        FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
-        FLAGS.gin_file,
-        FLAGS.gin_bindings)
-    train_using_gin()
-
-  gin_utils.run(main)
train/events.out.tfevents.1649073594.t1v-n-304587cf-w-0.1316481.0.v2 DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:224a0411c5fc4e0e882c7a647ff554b58fec3f79dc12f9809b26b3a319225c1d
-size 7585
train/events.out.tfevents.1649092520.t1v-n-304587cf-w-0.1399566.0.v2 DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d5a8424af960443ad6fbb097216b8add4fb9af5298424f440afb56a27ee260b9
-size 16363
train/{events.out.tfevents.1649056216.t1v-n-304587cf-w-0.1239745.0.v2 → events.out.tfevents.1649705066.t1v-n-304587cf-w-0.2549834.0.v2} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:766ae4ead7af1adb3e58f71c6b05f02a42c73eea497118cb6250ffa8c1c0bd2c
-size 7581
+oid sha256:59fc20fda0f88a5e31f18b0ebc9497d4131b6458ac28d6d2a52875d1ba5c5b13
+size 10402
training_eval/pretrain_finnish/events.out.tfevents.1649073594.t1v-n-304587cf-w-0.1316481.1.v2 DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7a4fbca2952ba8dbad98a5e752b6bcf0e80ea9d1376380bbbbd669b7fb0897e7
-size 1431
training_eval/pretrain_finnish/events.out.tfevents.1649092520.t1v-n-304587cf-w-0.1399566.1.v2 DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b3efa0a2b6af0ef032441aaf0e97ac69bab14c84b90bbbfa22001dc094926fb2
-size 9261
training_eval/pretrain_finnish/{events.out.tfevents.1649056216.t1v-n-304587cf-w-0.1239745.1.v2 → events.out.tfevents.1649705066.t1v-n-304587cf-w-0.2549834.1.v2} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:dec4b61e593fed96fc1b3b8511230bfb38bd4c2a2d1e780269d761122476c3f0
-size 1431
+oid sha256:6ec5c452edec8f5036cab8e5f3d67f492e55a15ff004d4ee5847b2d2cd56f2df
+size 4024