3v324v23 committed on
Commit
e557baa
1 Parent(s): cb25bdc

Update weights to checkpoint 140000

README.md CHANGED
@@ -25,15 +25,40 @@ For a demo of the model, head over to the Hugging Face Spaces for the **[Netherf
25
 
26
  ## Dataset
27
 
 
28
  `t5-base-dutch-demo` is fine-tuned on three mixed news sources:
29
 
30
  1. **CNN DailyMail** translated to Dutch with MarianMT.
31
  2. **XSUM** translated to Dutch with MarianMT.
32
  3. News article summaries distilled from the nu.nl website.
 
 
33
 
34
  ## Training
35
 
36
- The pre-trained model [t5-base-dutch](https://huggingface.co/flax-community/t5-base-dutch) was fine-tuned with a constant learning rate of 0.0005 and a batch size of 64 for 10,000 steps.
37
- The performance of this model can be improved with longer training. Unfortunately, due to a bug, an earlier training script did not save intermediate checkpoints. That run had been started for 6 epochs, which would have finished past the TPU-VM availability window; since limited time was left, the fine-tuning was restarted without evaluation and for only half an epoch (10,000 steps).

25
 
26
  ## Dataset
27
 
28
+
29
  `t5-base-dutch-demo` is fine-tuned on three mixed news sources:
30
 
31
  1. **CNN DailyMail** translated to Dutch with MarianMT.
32
  2. **XSUM** translated to Dutch with MarianMT.
33
  3. News article summaries distilled from the nu.nl website.
34
+
35
+ The total number of training examples in this dataset is 1,366,592.
36
 
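For reference, a minimal sketch of loading such a mixed JSON dataset with the Datasets library; the file names and the `complete_text`/`summary_text` columns follow `run.sh` in this commit, and the paths are placeholders:

```python
from datasets import load_dataset

# Placeholder file names; the actual data files live outside this repository.
data_files = {
    "train": "cnnuxsum_train.json",
    "validation": "cnnuxsum_val.json",
    "test": "cnnuxsum_test.json",
}
dataset = load_dataset("json", data_files=data_files)

# The summarization script reads these two columns (see run.sh).
example = dataset["train"][0]
print(example["complete_text"][:80], "->", example["summary_text"][:80])
```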
37
  ## Training
38
 
39
+ Training consisted of fine-tuning [t5-base-dutch](https://huggingface.co/flax-community/t5-base-dutch) with
40
+ the following parameters:
41
+
42
+ * Constant learning rate 0.0005
43
+ * Batch size 8
44
+ * 1 epoch (170,824 steps; see the arithmetic sketch after this list)
45
+
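As a sanity check, the step count follows directly from the dataset size and the batch size. A minimal sketch, assuming a single device so that the total batch size equals the per-device batch size of 8:

```python
num_examples = 1_366_592  # training examples, see the Dataset section above
batch_size = 8            # per-device batch size, single device assumed
num_epochs = 1

steps_per_epoch = num_examples // batch_size  # drops the final partial batch
total_steps = steps_per_epoch * num_epochs
print(total_steps)  # 170824
```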
46
+ ## Evaluation
47
+
48
+ The performance of the summarization model is measured with the Rouge metric from the
49
+ Hugging Face Datasets library.
50
+
51
+ ```
52
+ "rouge{n}" (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
53
+ "rougeL": Longest common subsequence based scoring.
54
+ "rougeLSum": rougeLsum splits text using "
55
+ "
56
+ ```
57
 
58
+ * Rouge1: 28.7066
59
+ * Rouge2: 9.5498
60
+ * RougeL: 22.8103
61
+ * RougeLsum: 24.2696
62
 
63
+ These scores are expected to improve when the model is trained and evaluated
64
+ on the CNN DailyMail and XSUM datasets (translated to Dutch) individually.
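A minimal sketch of computing these scores with the Datasets library, mirroring the usage in `run_summarization_flax.py` below; the prediction/reference pair is a toy placeholder:

```python
from datasets import load_metric

rouge = load_metric("rouge")
predictions = ["de kat zat op de mat"]
references = ["de kat lag op de mat"]

result = rouge.compute(
    predictions=predictions, references=references, use_stemmer=True
)
# Mid F1 scores, scaled to percentages as in the training script.
scores = {key: value.mid.fmeasure * 100 for key, value in result.items()}
print(scores)  # keys: rouge1, rouge2, rougeL, rougeLsum
```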
config.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "_name_or_path": "flax-community/t5-base-dutch-demo",
3
  "architectures": [
4
  "T5ForConditionalGeneration"
5
  ],
 
1
  {
2
+ "_name_or_path": "./",
3
  "architectures": [
4
  "T5ForConditionalGeneration"
5
  ],
output/events.out.tfevents.1626477704.t1v-n-0e7426e8-w-0.83817.3.v2 → events.out.tfevents.1626708806.yeb-z390-k80.10632.3.v2 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:091b61fc500aae0368d977c5c0fd73632a32aabebdc0e7fba4129f26b6c8abdf
3
- size 6630102
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:daeab64aaf6dd18fc097ee6bed7cd5e4e765e75716ca80c47777ad3b849b3679
3
+ size 19440898
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8299c056e5ca07f93db2db052d61cb941710e0925c62486ee0c9775116e0a6bf
3
  size 891548548
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ba1de1818d02f938ad913208487e569e15df1ce66ea9a2fa9580bb9f2a32f19
3
  size 891548548
output/ckpt-9999/config.json DELETED
@@ -1,58 +0,0 @@
1
- {
2
- "_name_or_path": ".",
3
- "architectures": [
4
- "T5ForConditionalGeneration"
5
- ],
6
- "d_ff": 3072,
7
- "d_kv": 64,
8
- "d_model": 768,
9
- "decoder_start_token_id": 0,
10
- "dropout_rate": 0.1,
11
- "eos_token_id": 1,
12
- "feed_forward_proj": "relu",
13
- "gradient_checkpointing": false,
14
- "initializer_factor": 1.0,
15
- "is_encoder_decoder": true,
16
- "layer_norm_epsilon": 1e-06,
17
- "model_type": "t5",
18
- "n_positions": 512,
19
- "num_decoder_layers": 12,
20
- "num_heads": 12,
21
- "num_layers": 12,
22
- "output_past": true,
23
- "pad_token_id": 0,
24
- "relative_attention_num_buckets": 32,
25
- "task_specific_params": {
26
- "summarization": {
27
- "early_stopping": true,
28
- "length_penalty": 2.0,
29
- "max_length": 200,
30
- "min_length": 30,
31
- "no_repeat_ngram_size": 3,
32
- "num_beams": 4,
33
- "prefix": "summarize: "
34
- },
35
- "translation_en_to_de": {
36
- "early_stopping": true,
37
- "max_length": 300,
38
- "num_beams": 4,
39
- "prefix": "translate English to German: "
40
- },
41
- "translation_en_to_fr": {
42
- "early_stopping": true,
43
- "max_length": 300,
44
- "num_beams": 4,
45
- "prefix": "translate English to French: "
46
- },
47
- "translation_en_to_ro": {
48
- "early_stopping": true,
49
- "max_length": 300,
50
- "num_beams": 4,
51
- "prefix": "translate English to Romanian: "
52
- }
53
- },
54
- "torch_dtype": "float32",
55
- "transformers_version": "4.9.0.dev0",
56
- "use_cache": true,
57
- "vocab_size": 32103
58
- }
output/ckpt-9999/flax_model.msgpack DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8299c056e5ca07f93db2db052d61cb941710e0925c62486ee0c9775116e0a6bf
3
- size 891548548
output/ckpt-9999/opt_state.msgpack DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7c912197fd24feea06a22802e5bfbd9935100bb392a8d1966e230891aeaec658
3
- size 1783097336
output/ckpt-9999/training_state.json DELETED
@@ -1 +0,0 @@
1
- {"step": 10000}
output/events.out.tfevents.1626504033.t1v-n-0e7426e8-w-0.89661.3.v2 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5b3b9f725bfa1e9befedd29c8c0319001a6ddc3597c6dfa30c754913531f26bc
3
- size 40
output/events.out.tfevents.1626504547.t1v-n-0e7426e8-w-0.93479.3.v2 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1d3bd0981ae5d2bb0ac2ffef88a1eac66f198c1a58e207b37e216a9997428160
3
- size 40
output/events.out.tfevents.1626505238.t1v-n-0e7426e8-w-0.95128.3.v2 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6b26b65cf4c438d4270d906ca1ed332fbe65f924ae82a22289dda08f95d5919f
3
- size 40
output/events.out.tfevents.1626506421.t1v-n-0e7426e8-w-0.96635.3.v2 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:d1a031810dfc4c6e7c52913e5261afc3fa3d5cf5a68695b76bbffd177b065e27
3
- size 40
output/events.out.tfevents.1626507299.t1v-n-0e7426e8-w-0.98584.3.v2 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4874aada340bc85728ebb4b7f8329a0eb6618a19f0b646abb1f1b5f2e9fc84fe
3
- size 40
output/events.out.tfevents.1626508342.t1v-n-0e7426e8-w-0.101251.3.v2 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ed227170f48707b46db61d657803869c0be10d350b75f29b0844a6ef8a9e0cd3
3
- size 40
output/flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8299c056e5ca07f93db2db052d61cb941710e0925c62486ee0c9775116e0a6bf
3
  size 891548548
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8e339d352658c4fae724883dc700cc559e7ab3eb7116139f6f0d187fe7720e1
3
  size 891548548
output/opt_state.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7c912197fd24feea06a22802e5bfbd9935100bb392a8d1966e230891aeaec658
3
  size 1783097336
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f2eb4ce40eafe6435258b3761c281883b93221092ca701e0cd1f21b78264297
3
  size 1783097336
output/training_state.json CHANGED
@@ -1 +1 @@
1
- {"step": 10000}
 
1
+ {"step": 140001}
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:837e804cfcfee38ffdbb87dc80de834a7c5aec62634910e6b2514794f848bba2
3
  size 891650495
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a9f97e22703c1a5cf20353b9e859f377c5fa709e5a96ee15ad40d66674b67fa
3
  size 891650495
run.sh CHANGED
@@ -1,12 +1,12 @@
1
  #!/bin/bash
2
- export CUDA_VISIBLE_DEVICES=1
3
 
4
  MODEL="flax-community/t5-base-dutch"
5
  OUTPUT="./output"
6
 
7
- TRAIN="/home/yeb/cnnuxsum/cnnuxsum_train.json"
8
- VAL="/home/yeb/cnnuxsum/cnnuxsum_val.json"
9
- TEST="/home/yeb/cnnuxsum/cnnuxsum_test.json"
10
 
11
  mkdir -p "${OUTPUT}"
12
 
@@ -15,48 +15,28 @@ python ./run_summarization_flax.py \
15
  --learning_rate "5e-4" \
16
  --warmup_steps 500 \
17
  --do_train \
 
 
18
  --train_file "${TRAIN}" \
19
  --validation_file "${VAL}" \
20
  --test_file "${TEST}" \
21
- --max_train_samples 640000 \
22
- --max_eval_samples 512 \
23
- --max_predict_samples 64 \
24
  --text_column "complete_text" \
25
  --summary_column "summary_text" \
26
- --source_prefix "summarize: " \
27
  --max_source_length 1024 \
28
  --max_target_length 142 \
29
  --output_dir "${OUTPUT}" \
30
  --per_device_train_batch_size=8 \
31
- --per_device_eval_batch_size=2 \
32
  --overwrite_output_dir \
33
  --num_train_epochs="1" \
34
- --logging_steps="50" \
35
- --save_steps="2000" \
36
- --eval_steps="25000000" \
37
- --num_beams 4
38
-
39
- # \
40
- # --do_predict
41
- # --do_eval \
42
-
43
-
44
- # \
45
- # --prediction_debug \
46
- # --predict_with_generate
47
-
48
-
49
-
50
 
51
  # --source_prefix "summarize: " \
52
-
53
- # --lr_scheduler_type="constant" \
54
-
55
- # --task "summarization" \
56
- # --early_stopping "true" \
57
- # --length_penalty "2.0" \
58
- # --max_length 300 \
59
- # --min_length 75 \
60
- # --no_repeat_ngram_size 3 \
61
- # --num_beams 4 \
62
- # --prefix "summarize: " \
 
1
  #!/bin/bash
2
+ export CUDA_VISIBLE_DEVICES="1"
3
 
4
  MODEL="flax-community/t5-base-dutch"
5
  OUTPUT="./output"
6
 
7
+ TRAIN="/home/yeb/Developer/data/cnnuxsum/cnnuxsum_train.json"
8
+ VAL="/home/yeb/Developer/data/cnnuxsum/cnnuxsum_val.json"
9
+ TEST="/home/yeb/Developer/data/cnnuxsum/cnnuxsum_test.json"
10
 
11
  mkdir -p "${OUTPUT}"
12
 
 
15
  --learning_rate "5e-4" \
16
  --warmup_steps 500 \
17
  --do_train \
18
+ --do_predict \
19
+ --do_eval \
20
  --train_file "${TRAIN}" \
21
  --validation_file "${VAL}" \
22
  --test_file "${TEST}" \
23
+ --max_train_samples 1366592 \
24
+ --max_eval_samples 32 \
25
+ --max_predict_samples 8 \
26
  --text_column "complete_text" \
27
  --summary_column "summary_text" \
 
28
  --max_source_length 1024 \
29
  --max_target_length 142 \
30
  --output_dir "${OUTPUT}" \
31
  --per_device_train_batch_size=8 \
32
+ --per_device_eval_batch_size=8 \
33
  --overwrite_output_dir \
34
  --num_train_epochs="1" \
35
+ --logging_steps="100" \
36
+ --save_steps="20000" \
37
+ --eval_steps="5000" \
38
+ --num_beams 4 \
39
+ --prediction_debug \
40
+ --predict_with_generate
41
 
42
  # --source_prefix "summarize: " \
 
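Before launching `run.sh`, it can help to verify that the JSON files expose the columns the flags assume (`complete_text` and `summary_text`). A minimal sketch, assuming the files are in JSON-lines format (one example per line), as the `datasets` JSON loader expects:

```python
import json

# Path as configured in run.sh; adjust for your environment.
path = "/home/yeb/Developer/data/cnnuxsum/cnnuxsum_val.json"

with open(path) as f:
    first = json.loads(f.readline())

assert "complete_text" in first and "summary_text" in first, sorted(first)
print("columns look OK:", sorted(first))
```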
run_summarization_flax.py CHANGED
@@ -90,20 +90,34 @@ class ModelArguments:
90
  )
91
  model_type: Optional[str] = field(
92
  default=None,
93
- metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
94
  )
95
  config_name: Optional[str] = field(
96
- default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
97
  )
98
  tokenizer_name: Optional[str] = field(
99
- default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
100
  )
101
  cache_dir: Optional[str] = field(
102
- default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
103
  )
104
  use_fast_tokenizer: bool = field(
105
  default=True,
106
- metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
 
 
107
  )
108
  dtype: Optional[str] = field(
109
  default="float32",
@@ -120,27 +134,41 @@ class DataTrainingArguments:
120
  """
121
 
122
  dataset_name: Optional[str] = field(
123
- default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
 
124
  )
125
  dataset_config_name: Optional[str] = field(
126
- default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
127
  )
128
  text_column: Optional[str] = field(
129
  default=None,
130
- metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
 
 
131
  )
132
  summary_column: Optional[str] = field(
133
  default=None,
134
- metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
135
  )
136
- train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
137
  validation_file: Optional[str] = field(
138
  default=None,
139
- metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
 
 
140
  )
141
  test_file: Optional[str] = field(
142
  default=None,
143
- metadata={"help": "An optional input evaluation data file to predict the perplexity on (a text file)."},
 
 
144
  )
145
  max_source_length: Optional[int] = field(
146
  default=1024,
@@ -191,10 +219,16 @@ class DataTrainingArguments:
191
  metadata={"help": "The number of processes to use for the preprocessing."},
192
  )
193
  source_prefix: Optional[str] = field(
194
- default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
195
  )
196
  predict_with_generate: bool = field(
197
- default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
198
  )
199
  num_beams: Optional[int] = field(
200
  default=None,
@@ -204,52 +238,52 @@ class DataTrainingArguments:
204
  },
205
  )
206
  overwrite_cache: bool = field(
207
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
 
208
  )
209
  prediction_debug: bool = field(
210
  default=False,
211
- metadata={
212
- "help": "Whether to show some examples of the model prediction"
213
- },
214
  )
215
 
216
  def __post_init__(self):
217
- if self.dataset_name is None and self.train_file is None and self.validation_file is None:
218
- raise ValueError("Need either a dataset name or a training/validation file.")
219
  else:
220
  if self.train_file is not None:
221
  extension = self.train_file.split(".")[-1]
222
- assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
223
  if self.validation_file is not None:
224
  extension = self.validation_file.split(".")[-1]
225
- assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
226
  if self.val_max_target_length is None:
227
  self.val_max_target_length = self.max_target_length
228
 
229
 
230
- summarization_name_mapping = {
231
- "amazon_reviews_multi": ("review_body", "review_title"),
232
- "big_patent": ("description", "abstract"),
233
- "cnn_dailymail": ("article", "highlights"),
234
- "orange_sum": ("text", "summary"),
235
- "pn_summary": ("article", "summary"),
236
- "psc": ("extract_text", "summary_text"),
237
- "samsum": ("dialogue", "summary"),
238
- "thaisum": ("body", "summary"),
239
- "xglue": ("news_body", "news_title"),
240
- "xsum": ("document", "summary"),
241
- "wiki_summary": ("article", "highlights"),
242
- }
243
-
244
-
245
  class TrainState(train_state.TrainState):
246
  dropout_rng: jnp.ndarray
247
 
248
  def replicate(self):
249
- return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
 
 
250
 
251
 
252
- def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
 
 
253
  """
254
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
255
  Shuffle batches if `shuffle` is `True`.
@@ -273,7 +307,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
273
  yield batch
274
 
275
 
276
- def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
277
  summary_writer.scalar("train_time", train_time, step)
278
 
279
  train_metrics = get_metrics(train_metrics)
@@ -282,21 +316,35 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
282
  for i, val in enumerate(vals):
283
  summary_writer.scalar(tag, val, step - len(vals) + i + 1)
284
 
 
 
285
  for metric_name, value in eval_metrics.items():
286
  summary_writer.scalar(f"eval_{metric_name}", value, step)
287
 
288
 
289
  def create_learning_rate_fn(
290
- train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
291
  ) -> Callable[[int], jnp.array]:
292
  """Returns a linear warmup, linear_decay learning rate function."""
293
  steps_per_epoch = train_ds_size // train_batch_size
294
  num_train_steps = steps_per_epoch * num_train_epochs
295
- warmup_fn = optax.linear_schedule(init_value=learning_rate, end_value=learning_rate, transition_steps=num_warmup_steps)
296
  decay_fn = optax.linear_schedule(
297
- init_value=learning_rate, end_value=learning_rate, transition_steps=num_train_steps - num_warmup_steps
298
  )
299
- schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
300
 
301
  return schedule_fn
302
 
@@ -306,11 +354,15 @@ def main():
306
  # or by passing the --help flag to this script.
307
  # We now keep distinct sets of args, for a cleaner separation of concerns.
308
 
309
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
 
 
310
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
311
  # If we pass only one argument to the script and it's the path to a json file,
312
  # let's parse it to get our arguments.
313
- model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
 
 
314
  else:
315
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
316
 
@@ -334,11 +386,7 @@ def main():
334
  state = jax_utils.unreplicate(state)
335
  logger.info(f"SAVING CHECKPOINT IN {save_dir}")
336
  save_dir = f"{save_dir}/ckpt-{mb_item(state.step) - 1}"
337
- model.save_pretrained(
338
- save_dir,
339
- params=state.params,
340
- push_to_hub=False
341
- )
342
  if with_opt:
343
  with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
344
  f.write(to_bytes(state.opt_state))
@@ -352,9 +400,13 @@ def main():
352
  # commit_message=f"Saving weights and logs of step {cur_step}",
353
  # )
354
  if with_opt:
355
- with open(os.path.join(training_args.output_dir, "opt_state.msgpack"), "wb") as f:
 
 
356
  f.write(to_bytes(state.opt_state))
357
- with open(os.path.join(training_args.output_dir, "training_state.json"), "w") as f:
 
 
358
  json.dump({"step": state.step.item()}, f)
359
  logger.info("checkpoint saved")
360
 
@@ -386,7 +438,10 @@ def main():
386
  if data_args.dataset_name is not None:
387
  # Downloading and loading a dataset from the hub.
388
  dataset = load_dataset(
389
- data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
390
  )
391
  else:
392
  data_files = {}
@@ -399,27 +454,37 @@ def main():
399
  if data_args.test_file is not None:
400
  data_files["test"] = data_args.test_file
401
  extension = data_args.test_file.split(".")[-1]
402
- dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
 
 
403
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
404
  # https://huggingface.co/docs/datasets/loading_datasets.html.
405
 
406
  # Load pretrained model and tokenizer
407
 
408
  if model_args.config_name:
409
- config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
 
 
410
  elif model_args.model_name_or_path:
411
- config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
 
 
412
  else:
413
  config = CONFIG_MAPPING[model_args.model_type]()
414
  logger.warning("You are instantiating a new config instance from scratch.")
415
 
416
  if model_args.tokenizer_name:
417
  tokenizer = AutoTokenizer.from_pretrained(
418
- model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
 
 
419
  )
420
  elif model_args.model_name_or_path:
421
  tokenizer = AutoTokenizer.from_pretrained(
422
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
 
 
423
  )
424
  else:
425
  raise ValueError(
@@ -429,7 +494,10 @@ def main():
429
 
430
  if model_args.model_name_or_path:
431
  model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
432
- model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
433
  )
434
  else:
435
  model = FlaxAutoModelForSeq2SeqLM.from_config(
@@ -437,7 +505,9 @@ def main():
437
  )
438
 
439
  if model.config.decoder_start_token_id is None:
440
- raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
 
 
441
 
442
  prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
443
 
@@ -450,13 +520,14 @@ def main():
450
  elif training_args.do_predict:
451
  column_names = dataset["test"].column_names
452
  else:
453
- logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
 
 
454
  return
455
 
456
  # Get the column names for input/target.
457
- dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
458
  if data_args.text_column is None:
459
- text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
460
  else:
461
  text_column = data_args.text_column
462
  if text_column not in column_names:
@@ -464,7 +535,7 @@ def main():
464
  f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
465
  )
466
  if data_args.summary_column is None:
467
- summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
468
  else:
469
  summary_column = data_args.summary_column
470
  if summary_column not in column_names:
@@ -487,18 +558,28 @@ def main():
487
  targets = examples[summary_column]
488
  inputs = [prefix + inp for inp in inputs]
489
  model_inputs = tokenizer(
490
- inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
491
  )
492
 
493
  # Setup the tokenizer for targets
494
  with tokenizer.as_target_tokenizer():
495
  labels = tokenizer(
496
- targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
497
  )
498
 
499
  model_inputs["labels"] = labels["input_ids"]
500
  decoder_input_ids = shift_tokens_right_fn(
501
- jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id
 
 
502
  )
503
  model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
504
 
@@ -544,7 +625,9 @@ def main():
544
  raise ValueError("--do_predict requires a test dataset")
545
  predict_dataset = dataset["test"]
546
  if data_args.max_predict_samples is not None:
547
- predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
 
 
548
  predict_dataset = predict_dataset.map(
549
  preprocess_function,
550
  batched=True,
@@ -553,6 +636,14 @@ def main():
553
  load_from_cache_file=not data_args.overwrite_cache,
554
  desc="Running tokenizer on prediction dataset",
555
  )
556
 
557
  # Metric
558
  metric = load_metric("rouge")
@@ -578,13 +669,28 @@ def main():
578
  for index in random.sample(range(len(decoded_labels)), 3):
579
  logger.info(f'reference: "{decoded_labels[index]}"')
580
  logger.info(f'predicted: "{decoded_preds[index]}"')
581
- logger.info('---')
582
 
583
- result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
 
 
584
  # Extract a few results from ROUGE
585
  result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
586
 
587
- prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
588
  result["gen_len"] = np.mean(prediction_lens)
589
  result = {k: round(v, 4) for k, v in result.items()}
590
  return result
@@ -595,7 +701,7 @@ def main():
595
  try:
596
  from flax.metrics.tensorboard import SummaryWriter
597
 
598
- summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
599
  except ImportError as ie:
600
  has_tensorboard = False
601
  logger.warning(
@@ -613,7 +719,9 @@ def main():
613
 
614
  # Store some constant
615
  num_epochs = int(training_args.num_train_epochs)
616
- train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
 
 
617
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
618
  steps_per_epoch = len(train_dataset) // train_batch_size
619
  total_train_steps = steps_per_epoch * num_epochs
@@ -634,13 +742,36 @@ def main():
634
  # Note that this mask is specifically adapted for FlaxBart.
635
  # For FlaxT5, one should correct the layer norm parameter naming
636
  # accordingly - see `run_t5_mlm_flax.py` e.g.
637
- def decay_mask_fn(params):
638
- flat_params = traverse_util.flatten_dict(params)
639
- layer_norm_params = [
640
- (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
641
- ]
642
- flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
643
- return traverse_util.unflatten_dict(flat_mask)
644
 
645
  # create adam optimizer
646
  adamw = optax.adamw(
@@ -653,7 +784,9 @@ def main():
653
  )
654
 
655
  # Setup train state
656
- state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
 
 
657
 
658
  # label smoothed cross entropy
659
  def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
@@ -665,9 +798,12 @@ def main():
665
  confidence = 1.0 - label_smoothing_factor
666
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
667
  normalizing_constant = -(
668
- confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
669
  )
670
- soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
671
 
672
  loss = optax.softmax_cross_entropy(logits, soft_labels)
673
  loss = loss - normalizing_constant
@@ -683,8 +819,12 @@ def main():
683
 
684
  def compute_loss(params):
685
  labels = batch.pop("labels")
686
- logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
687
- loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
688
  return loss
689
 
690
  grad_fn = jax.value_and_grad(compute_loss)
@@ -693,7 +833,10 @@ def main():
693
 
694
  new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
695
 
696
- metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
697
  metrics = jax.lax.pmean(metrics, axis_name="batch")
698
 
699
  return new_state, metrics
@@ -702,7 +845,9 @@ def main():
702
  def eval_step(params, batch, label_smoothing_factor=0.0):
703
  labels = batch.pop("labels")
704
  logits = model(**batch, params=params, train=False)[0]
705
- loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
 
 
706
 
707
  # summarize metrics
708
  metrics = {"loss": loss}
@@ -711,21 +856,36 @@ def main():
711
 
712
  # Define generation function
713
  max_length = (
714
- data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
715
  )
716
- num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
717
  gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
718
 
719
  def generate_step(params, batch):
720
  model.params = params
721
- output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
 
 
722
  return output_ids.sequences
723
 
724
  # Create parallel version of the train and eval step
725
  p_train_step = jax.pmap(
726
- partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
727
  )
728
- p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
729
  p_generate_step = jax.pmap(generate_step, "batch")
730
 
731
  # Replicate the train state on each device
@@ -734,11 +894,16 @@ def main():
734
  logger.info("***** Running training *****")
735
  logger.info(f" Num examples = {len(train_dataset)}")
736
  logger.info(f" Num Epochs = {num_epochs}")
737
- logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
738
- logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
739
  logger.info(f" Total optimization steps = {total_train_steps}")
740
 
741
  train_time = 0
 
742
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
743
  for epoch in epochs:
744
  # ======================== Training ================================
@@ -746,117 +911,160 @@ def main():
746
 
747
  # Create sampling rng
748
  rng, input_rng = jax.random.split(rng)
749
- train_metrics = []
750
 
751
  # Generate an epoch by shuffling sampling indices from the train dataset
752
- train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
 
 
753
  steps_per_epoch = len(train_dataset) // train_batch_size
754
  # train
755
- for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
 
 
756
  batch = next(train_loader)
757
  state, train_metric = p_train_step(state, batch)
758
  train_metrics.append(train_metric)
759
 
760
- train_time += time.time() - train_start
761
-
762
- train_metric = unreplicate(train_metric)
763
-
764
- epochs.write(
765
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
766
- )
767
-
768
- # save checkpoint after each epoch and push checkpoint to the hub
769
- if jax.process_index() == 0:
770
- # params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
771
- # model.save_pretrained(
772
- # training_args.output_dir,
773
- # params=params,
774
- # push_to_hub=training_args.push_to_hub,
775
- # commit_message=f"Saving weights and logs of epoch {epoch+1}",
776
- # )
777
- save_checkpoint(model, training_args.output_dir, state)
778
-
779
- # ======================== Evaluating ==============================
780
- if training_args.do_eval:
781
- eval_metrics = []
782
- eval_preds = []
783
- eval_labels = []
784
-
785
- eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
786
- eval_steps = len(eval_dataset) // eval_batch_size
787
- for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
788
- # Model forward
789
- batch = next(eval_loader)
790
- labels = batch["labels"]
791
-
792
- metrics = p_eval_step(state.params, batch)
793
- eval_metrics.append(metrics)
794
-
795
- # generation
796
  if data_args.predict_with_generate:
797
- generated_ids = p_generate_step(state.params, batch)
798
- eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
799
- eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
800
-
801
- # normalize eval metrics
802
- eval_metrics = get_metrics(eval_metrics)
803
- eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
804
-
805
- # compute ROUGE metrics
806
- rouge_desc = ""
807
- if data_args.predict_with_generate:
808
- rouge_metrics = compute_metrics(eval_preds, eval_labels)
809
- eval_metrics.update(rouge_metrics)
810
- rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
811
-
812
- # Print metrics and update progress bar
813
- desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
814
- epochs.write(desc)
815
- epochs.desc = desc
816
-
817
- # Save metrics
818
- if has_tensorboard and jax.process_index() == 0:
819
- cur_step = epoch * (len(train_dataset) // train_batch_size)
820
- write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
821
-
822
- # ======================== Prediction loop ==============================
823
- if training_args.do_predict:
824
- logger.info("*** Predict ***")
825
-
826
- pred_metrics = []
827
- pred_generations = []
828
- pred_labels = []
829
-
830
- pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
831
- pred_steps = len(predict_dataset) // eval_batch_size
832
- for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
833
- # Model forward
834
- batch = next(pred_loader)
835
- labels = batch["labels"]
836
-
837
- metrics = p_eval_step(state.params, batch)
838
- pred_metrics.append(metrics)
839
-
840
- # generation
841
- if data_args.predict_with_generate:
842
- generated_ids = p_generate_step(state.params, batch)
843
- pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
844
- pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
845
-
846
- # normalize prediction metrics
847
- pred_metrics = get_metrics(pred_metrics)
848
- pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
849
-
850
- # compute ROUGE metrics
851
- rouge_desc = ""
852
- if data_args.predict_with_generate:
853
- rouge_metrics = compute_metrics(pred_generations, pred_labels)
854
- pred_metrics.update(rouge_metrics)
855
- rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
856
-
857
- # Print metrics
858
- desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
859
- logger.info(desc)
860
 
861
  # save checkpoint after each epoch and push checkpoint to the hub
862
  if jax.process_index() == 0:
@@ -867,8 +1075,6 @@ def main():
867
  push_to_hub=training_args.push_to_hub,
868
  commit_message=f"Saving weights and logs of epoch {epoch+1}",
869
  )
870
- # save_checkpoint(model, training_args.output_dir, state)
871
-
872
 
873
 
874
  if __name__ == "__main__":
 
90
  )
91
  model_type: Optional[str] = field(
92
  default=None,
93
+ metadata={
94
+ "help": "If training from scratch, pass a model type from the list: "
95
+ + ", ".join(MODEL_TYPES)
96
+ },
97
  )
98
  config_name: Optional[str] = field(
99
+ default=None,
100
+ metadata={
101
+ "help": "Pretrained config name or path if not the same as model_name"
102
+ },
103
  )
104
  tokenizer_name: Optional[str] = field(
105
+ default=None,
106
+ metadata={
107
+ "help": "Pretrained tokenizer name or path if not the same as model_name"
108
+ },
109
  )
110
  cache_dir: Optional[str] = field(
111
+ default=None,
112
+ metadata={
113
+ "help": "Where do you want to store the pretrained models downloaded from s3"
114
+ },
115
  )
116
  use_fast_tokenizer: bool = field(
117
  default=True,
118
+ metadata={
119
+ "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
120
+ },
121
  )
122
  dtype: Optional[str] = field(
123
  default="float32",
 
134
  """
135
 
136
  dataset_name: Optional[str] = field(
137
+ default=None,
138
+ metadata={"help": "The name of the dataset to use (via the datasets library)."},
139
  )
140
  dataset_config_name: Optional[str] = field(
141
+ default=None,
142
+ metadata={
143
+ "help": "The configuration name of the dataset to use (via the datasets library)."
144
+ },
145
  )
146
  text_column: Optional[str] = field(
147
  default=None,
148
+ metadata={
149
+ "help": "The name of the column in the datasets containing the full texts (for summarization)."
150
+ },
151
  )
152
  summary_column: Optional[str] = field(
153
  default=None,
154
+ metadata={
155
+ "help": "The name of the column in the datasets containing the summaries (for summarization)."
156
+ },
157
+ )
158
+ train_file: Optional[str] = field(
159
+ default=None, metadata={"help": "The input training data file (a text file)."}
160
  )
 
161
  validation_file: Optional[str] = field(
162
  default=None,
163
+ metadata={
164
+ "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
165
+ },
166
  )
167
  test_file: Optional[str] = field(
168
  default=None,
169
+ metadata={
170
+ "help": "An optional input evaluation data file to predict the perplexity on (a text file)."
171
+ },
172
  )
173
  max_source_length: Optional[int] = field(
174
  default=1024,
 
219
  metadata={"help": "The number of processes to use for the preprocessing."},
220
  )
221
  source_prefix: Optional[str] = field(
222
+ default=None,
223
+ metadata={
224
+ "help": "A prefix to add before every source text (useful for T5 models)."
225
+ },
226
  )
227
  predict_with_generate: bool = field(
228
+ default=False,
229
+ metadata={
230
+ "help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."
231
+ },
232
  )
233
  num_beams: Optional[int] = field(
234
  default=None,
 
238
  },
239
  )
240
  overwrite_cache: bool = field(
241
+ default=False,
242
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
243
  )
244
  prediction_debug: bool = field(
245
  default=False,
246
+ metadata={"help": "Whether to show some examples of the model prediction"},
 
 
247
  )
248
 
249
  def __post_init__(self):
250
+ if (
251
+ self.dataset_name is None
252
+ and self.train_file is None
253
+ and self.validation_file is None
254
+ ):
255
+ raise ValueError(
256
+ "Need either a dataset name or a training/validation file."
257
+ )
258
  else:
259
  if self.train_file is not None:
260
  extension = self.train_file.split(".")[-1]
261
+ assert extension in [
262
+ "csv",
263
+ "json",
264
+ ], "`train_file` should be a csv or a json file."
265
  if self.validation_file is not None:
266
  extension = self.validation_file.split(".")[-1]
267
+ assert extension in [
268
+ "csv",
269
+ "json",
270
+ ], "`validation_file` should be a csv or a json file."
271
  if self.val_max_target_length is None:
272
  self.val_max_target_length = self.max_target_length
273
 
274
 
275
  class TrainState(train_state.TrainState):
276
  dropout_rng: jnp.ndarray
277
 
278
  def replicate(self):
279
+ return jax_utils.replicate(self).replace(
280
+ dropout_rng=shard_prng_key(self.dropout_rng)
281
+ )
282
 
283
 
284
+ def data_loader(
285
+ rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False
286
+ ):
287
  """
288
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
289
  Shuffle batches if `shuffle` is `True`.
 
307
  yield batch
308
 
309
 
310
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
311
  summary_writer.scalar("train_time", train_time, step)
312
 
313
  train_metrics = get_metrics(train_metrics)
 
316
  for i, val in enumerate(vals):
317
  summary_writer.scalar(tag, val, step - len(vals) + i + 1)
318
 
319
+
320
+ def write_eval_metric(summary_writer, eval_metrics, step):
321
  for metric_name, value in eval_metrics.items():
322
  summary_writer.scalar(f"eval_{metric_name}", value, step)
323
 
324
 
325
  def create_learning_rate_fn(
326
+ train_ds_size: int,
327
+ train_batch_size: int,
328
+ num_train_epochs: int,
329
+ num_warmup_steps: int,
330
+ learning_rate: float,
331
  ) -> Callable[[int], jnp.array]:
332
  """Returns a linear warmup, linear_decay learning rate function."""
333
  steps_per_epoch = train_ds_size // train_batch_size
334
  num_train_steps = steps_per_epoch * num_train_epochs
335
+ warmup_fn = optax.linear_schedule(
336
+ init_value=learning_rate,
337
+ end_value=learning_rate,
338
+ transition_steps=num_warmup_steps,
339
+ )
340
  decay_fn = optax.linear_schedule(
341
+ init_value=learning_rate,
342
+ end_value=learning_rate,
343
+ transition_steps=num_train_steps - num_warmup_steps,
344
+ )
345
+ schedule_fn = optax.join_schedules(
346
+ schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
347
  )
 
348
 
349
  return schedule_fn
350
 
 
354
  # or by passing the --help flag to this script.
355
  # We now keep distinct sets of args, for a cleaner separation of concerns.
356
 
357
+ parser = HfArgumentParser(
358
+ (ModelArguments, DataTrainingArguments, TrainingArguments)
359
+ )
360
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
361
  # If we pass only one argument to the script and it's the path to a json file,
362
  # let's parse it to get our arguments.
363
+ model_args, data_args, training_args = parser.parse_json_file(
364
+ json_file=os.path.abspath(sys.argv[1])
365
+ )
366
  else:
367
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
368
 
 
386
  state = jax_utils.unreplicate(state)
387
  logger.info(f"SAVING CHECKPOINT IN {save_dir}")
388
  save_dir = f"{save_dir}/ckpt-{mb_item(state.step) - 1}"
389
+ model.save_pretrained(save_dir, params=state.params, push_to_hub=False)
390
  if with_opt:
391
  with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
392
  f.write(to_bytes(state.opt_state))
 
400
  # commit_message=f"Saving weights and logs of step {cur_step}",
401
  # )
402
  if with_opt:
403
+ with open(
404
+ os.path.join(training_args.output_dir, "opt_state.msgpack"), "wb"
405
+ ) as f:
406
  f.write(to_bytes(state.opt_state))
407
+ with open(
408
+ os.path.join(training_args.output_dir, "training_state.json"), "w"
409
+ ) as f:
410
  json.dump({"step": state.step.item()}, f)
411
  logger.info("checkpoint saved")
412
 
 
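For completeness, a minimal sketch of restoring such a checkpoint; this is an assumption about usage, not part of this commit. `flax.serialization.from_bytes` deserializes `opt_state.msgpack` into an optimizer state with the same pytree structure, and `state` stands for the `TrainState` created later in the script:

```python
import json

from flax.serialization import from_bytes

output_dir = "./output"  # mirrors the files written by save_checkpoint above

with open(f"{output_dir}/opt_state.msgpack", "rb") as f:
    # `state.opt_state` must have the same structure it was saved with.
    restored_opt_state = from_bytes(state.opt_state, f.read())

with open(f"{output_dir}/training_state.json") as f:
    start_step = json.load(f)["step"]

state = state.replace(opt_state=restored_opt_state, step=start_step)
```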
438
  if data_args.dataset_name is not None:
439
  # Downloading and loading a dataset from the hub.
440
  dataset = load_dataset(
441
+ data_args.dataset_name,
442
+ data_args.dataset_config_name,
443
+ cache_dir=model_args.cache_dir,
444
+ keep_in_memory=False,
445
  )
446
  else:
447
  data_files = {}
 
454
  if data_args.test_file is not None:
455
  data_files["test"] = data_args.test_file
456
  extension = data_args.test_file.split(".")[-1]
457
+ dataset = load_dataset(
458
+ extension, data_files=data_files, cache_dir=model_args.cache_dir
459
+ )
460
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
461
  # https://huggingface.co/docs/datasets/loading_datasets.html.
462
 
463
  # Load pretrained model and tokenizer
464
 
465
  if model_args.config_name:
466
+ config = AutoConfig.from_pretrained(
467
+ model_args.config_name, cache_dir=model_args.cache_dir
468
+ )
469
  elif model_args.model_name_or_path:
470
+ config = AutoConfig.from_pretrained(
471
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir
472
+ )
473
  else:
474
  config = CONFIG_MAPPING[model_args.model_type]()
475
  logger.warning("You are instantiating a new config instance from scratch.")
476
 
477
  if model_args.tokenizer_name:
478
  tokenizer = AutoTokenizer.from_pretrained(
479
+ model_args.tokenizer_name,
480
+ cache_dir=model_args.cache_dir,
481
+ use_fast=model_args.use_fast_tokenizer,
482
  )
483
  elif model_args.model_name_or_path:
484
  tokenizer = AutoTokenizer.from_pretrained(
485
+ model_args.model_name_or_path,
486
+ cache_dir=model_args.cache_dir,
487
+ use_fast=model_args.use_fast_tokenizer,
488
  )
489
  else:
490
  raise ValueError(
 
494
 
495
  if model_args.model_name_or_path:
496
  model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
497
+ model_args.model_name_or_path,
498
+ config=config,
499
+ seed=training_args.seed,
500
+ dtype=getattr(jnp, model_args.dtype),
501
  )
502
  else:
503
  model = FlaxAutoModelForSeq2SeqLM.from_config(
 
505
  )
506
 
507
  if model.config.decoder_start_token_id is None:
508
+ raise ValueError(
509
+ "Make sure that `config.decoder_start_token_id` is correctly defined"
510
+ )
511
 
512
  prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
513
 
 
520
  elif training_args.do_predict:
521
  column_names = dataset["test"].column_names
522
  else:
523
+ logger.info(
524
+ "There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`."
525
+ )
526
  return
527
 
528
  # Get the column names for input/target.
 
529
  if data_args.text_column is None:
530
+ text_column = column_names[0]
531
  else:
532
  text_column = data_args.text_column
533
  if text_column not in column_names:
 
535
  f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
536
  )
537
  if data_args.summary_column is None:
538
+ summary_column = column_names[1]
539
  else:
540
  summary_column = data_args.summary_column
541
  if summary_column not in column_names:
 
558
  targets = examples[summary_column]
559
  inputs = [prefix + inp for inp in inputs]
560
  model_inputs = tokenizer(
561
+ inputs,
562
+ max_length=data_args.max_source_length,
563
+ padding="max_length",
564
+ truncation=True,
565
+ return_tensors="np",
566
  )
567
 
568
  # Setup the tokenizer for targets
569
  with tokenizer.as_target_tokenizer():
570
  labels = tokenizer(
571
+ targets,
572
+ max_length=max_target_length,
573
+ padding="max_length",
574
+ truncation=True,
575
+ return_tensors="np",
576
  )
577
 
578
  model_inputs["labels"] = labels["input_ids"]
579
  decoder_input_ids = shift_tokens_right_fn(
580
+ jnp.array(labels["input_ids"]),
581
+ config.pad_token_id,
582
+ config.decoder_start_token_id,
583
  )
584
  model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
585
 
 
625
  raise ValueError("--do_predict requires a test dataset")
626
  predict_dataset = dataset["test"]
627
  if data_args.max_predict_samples is not None:
628
+ predict_dataset = predict_dataset.select(
629
+ range(data_args.max_predict_samples)
630
+ )
631
  predict_dataset = predict_dataset.map(
632
  preprocess_function,
633
  batched=True,
 
636
  load_from_cache_file=not data_args.overwrite_cache,
637
  desc="Running tokenizer on prediction dataset",
638
  )
639
+ eval_batch_size = (
640
+ int(training_args.per_device_eval_batch_size) * jax.device_count()
641
+ )
642
+ pred_steps = len(predict_dataset) // eval_batch_size
643
+ if pred_steps == 0:
644
+ raise Exception(
645
+ "The length of the prediction dataset // eval batch size is 0. Increase prediction dataset size"
646
+ )
647
 
648
  # Metric
649
  metric = load_metric("rouge")
 
669
  for index in random.sample(range(len(decoded_labels)), 3):
670
  logger.info(f'reference: "{decoded_labels[index]}"')
671
  logger.info(f'predicted: "{decoded_preds[index]}"')
672
+ logger.info("---")
673
 
674
+ result = metric.compute(
675
+ predictions=decoded_preds, references=decoded_labels, use_stemmer=True
676
+ )
677
  # Extract a few results from ROUGE
678
  result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
679
 
680
+ try:
681
+ result_bleu = bleu.compute(
682
+ predictions=decoded_preds, references=decoded_labels_bleu
683
+ )
684
+ result_bleu = result_bleu["score"]
685
+ except Exception as e:
686
+ logger.info(f"Error occurred during BLEU computation: {e}")
687
+ result_bleu = 0.0
688
+ result["bleu"] = result_bleu
689
+
690
+
691
+ prediction_lens = [
692
+ np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
693
+ ]
694
  result["gen_len"] = np.mean(prediction_lens)
695
  result = {k: round(v, 4) for k, v in result.items()}
696
  return result
 
701
  try:
702
  from flax.metrics.tensorboard import SummaryWriter
703
 
704
+ summary_writer = SummaryWriter(log_dir=Path(training_args.logging_dir))
705
  except ImportError as ie:
706
  has_tensorboard = False
707
  logger.warning(
 
719
 
720
  # Store some constant
721
  num_epochs = int(training_args.num_train_epochs)
722
+ train_batch_size = (
723
+ int(training_args.per_device_train_batch_size) * jax.device_count()
724
+ )
725
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
726
  steps_per_epoch = len(train_dataset) // train_batch_size
727
  total_train_steps = steps_per_epoch * num_epochs
 
742
  # Note that this mask is specifically adapted for FlaxBart.
743
  # For FlaxT5, one should correct the layer norm parameter naming
744
  # accordingly - see `run_t5_mlm_flax.py` e.g.
745
+ if config.model_type in ["t5", "mt5", "byt5"]:
746
+
747
+ def decay_mask_fn(params):
748
+ flat_params = traverse_util.flatten_dict(params)
749
+ layer_norm_params = [
750
+ (name, "scale") for name in ["layer_norm", "final_layer_norm"]
751
+ ]
752
+ flat_mask = {
753
+ path: (path[-1] != "bias" and path[-2:] not in layer_norm_params)
754
+ for path in flat_params
755
+ }
756
+ return traverse_util.unflatten_dict(flat_mask)
757
+
758
+ else:
759
+
760
+ def decay_mask_fn(params):
761
+ flat_params = traverse_util.flatten_dict(params)
762
+ layer_norm_params = [
763
+ (name, "scale")
764
+ for name in [
765
+ "self_attn_layer_norm",
766
+ "layernorm_embedding",
767
+ "final_layer_norm",
768
+ ]
769
+ ]
770
+ flat_mask = {
771
+ path: (path[-1] != "bias" and path[-2:] not in layer_norm_params)
772
+ for path in flat_params
773
+ }
774
+ return traverse_util.unflatten_dict(flat_mask)
775
 
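A toy sketch of what the T5 branch of `decay_mask_fn` produces (the parameter tree is illustrative, not actual T5 parameters): biases and layer-norm scales map to `False`, everything else to `True`, so AdamW applies weight decay only to the `True` leaves:

```python
from flax import traverse_util

# Illustrative parameter tree; real parameters come from model.params.
params = {
    "encoder": {
        "dense": {"kernel": 1.0, "bias": 0.0},
        "layer_norm": {"scale": 1.0},
    }
}

flat_params = traverse_util.flatten_dict(params)
layer_norm_params = [(name, "scale") for name in ["layer_norm", "final_layer_norm"]]
flat_mask = {
    path: (path[-1] != "bias" and path[-2:] not in layer_norm_params)
    for path in flat_params
}
print(traverse_util.unflatten_dict(flat_mask))
# {'encoder': {'dense': {'bias': False, 'kernel': True}, 'layer_norm': {'scale': False}}}
```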
776
  # create adam optimizer
777
  adamw = optax.adamw(
 
784
  )
785
 
786
  # Setup train state
787
+ state = TrainState.create(
788
+ apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng
789
+ )
790
 
791
  # label smoothed cross entropy
792
  def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
 
798
  confidence = 1.0 - label_smoothing_factor
799
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
800
  normalizing_constant = -(
801
+ confidence * jnp.log(confidence)
802
+ + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
803
+ )
804
+ soft_labels = onehot(
805
+ labels, vocab_size, on_value=confidence, off_value=low_confidence
806
  )
 
807
 
808
  loss = optax.softmax_cross_entropy(logits, soft_labels)
809
  loss = loss - normalizing_constant
 
819
 
820
  def compute_loss(params):
821
  labels = batch.pop("labels")
822
+ logits = state.apply_fn(
823
+ **batch, params=params, dropout_rng=dropout_rng, train=True
824
+ )[0]
825
+ loss = loss_fn(
826
+ logits, labels, batch["decoder_attention_mask"], label_smoothing_factor
827
+ )
828
  return loss
829
 
830
  grad_fn = jax.value_and_grad(compute_loss)
 
833
 
834
  new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
835
 
836
+ metrics = {
837
+ "loss": loss,
838
+ "learning_rate": linear_decay_lr_schedule_fn(state.step),
839
+ }
840
  metrics = jax.lax.pmean(metrics, axis_name="batch")
841
 
842
  return new_state, metrics
 
845
  def eval_step(params, batch, label_smoothing_factor=0.0):
846
  labels = batch.pop("labels")
847
  logits = model(**batch, params=params, train=False)[0]
848
+ loss = loss_fn(
849
+ logits, labels, batch["decoder_attention_mask"], label_smoothing_factor
850
+ )
851
 
852
  # summarize metrics
853
  metrics = {"loss": loss}
 
856
 
857
  # Define generation function
858
  max_length = (
859
+ data_args.val_max_target_length
860
+ if data_args.val_max_target_length is not None
861
+ else model.config.max_length
862
+ )
863
+ num_beams = (
864
+ data_args.num_beams
865
+ if data_args.num_beams is not None
866
+ else model.config.num_beams
867
  )
 
868
  gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
869
 
870
  def generate_step(params, batch):
871
  model.params = params
872
+ output_ids = model.generate(
873
+ batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs
874
+ )
875
  return output_ids.sequences
876
 
877
  # Create parallel version of the train and eval step
878
  p_train_step = jax.pmap(
879
+ partial(
880
+ train_step, label_smoothing_factor=training_args.label_smoothing_factor
881
+ ),
882
+ "batch",
883
+ donate_argnums=(0,),
884
+ )
885
+ p_eval_step = jax.pmap(
886
+ partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor),
887
+ "batch",
888
  )
 
889
  p_generate_step = jax.pmap(generate_step, "batch")
890
 
891
  # Replicate the train state on each device
 
894
  logger.info("***** Running training *****")
895
  logger.info(f" Num examples = {len(train_dataset)}")
896
  logger.info(f" Num Epochs = {num_epochs}")
897
+ logger.info(
898
+ f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
899
+ )
900
+ logger.info(
901
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
902
+ )
903
  logger.info(f" Total optimization steps = {total_train_steps}")
904
 
905
  train_time = 0
906
+ train_metrics = []
907
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
908
  for epoch in epochs:
909
  # ======================== Training ================================
 
911
 
912
  # Create sampling rng
913
  rng, input_rng = jax.random.split(rng)
 
914
 
915
  # Generate an epoch by shuffling sampling indices from the train dataset
916
+ train_loader = data_loader(
917
+ input_rng, train_dataset, train_batch_size, shuffle=True
918
+ )
919
  steps_per_epoch = len(train_dataset) // train_batch_size
920
  # train
921
+ for step in tqdm(
922
+ range(steps_per_epoch), desc="Training...", position=1, leave=False
923
+ ):
924
  batch = next(train_loader)
925
  state, train_metric = p_train_step(state, batch)
926
  train_metrics.append(train_metric)
927
 
928
+ cur_step = epoch * (len(train_dataset) // train_batch_size) + step
929
+
930
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
931
+ # Save metrics
932
+ train_metric = unreplicate(train_metric)
933
+ train_time += time.time() - train_start
934
+
935
+ if has_tensorboard and jax.process_index() == 0:
936
+ logger.info(
937
+ f"*** Writing training summary after {cur_step} steps ***"
938
+ )
939
+ write_train_metric(
940
+ summary_writer, train_metrics, train_time, cur_step
941
+ )
942
+
943
+ epochs.write(
944
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
945
+ )
946
+
947
+ train_metrics = []
948
+
949
+ if (
950
+ training_args.do_eval
951
+ and cur_step % training_args.eval_steps == 0
952
+ and cur_step > 0
953
+ ):
954
+ logger.info(f"*** Evaluation after {cur_step} steps ***")
955
+ eval_metrics = []
956
+ eval_preds = []
957
+ eval_labels = []
958
+
959
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
960
+ eval_steps = len(eval_dataset) // eval_batch_size
961
+ for _ in tqdm(
962
+ range(eval_steps), desc="Evaluating...", position=2, leave=False
963
+ ):
964
+ # Model forward
965
+ batch = next(eval_loader)
966
+ labels = batch["labels"]
967
+
968
+ metrics = p_eval_step(state.params, batch)
969
+ eval_metrics.append(metrics)
970
+
971
+ # generation
972
+ if data_args.predict_with_generate:
973
+ generated_ids = p_generate_step(state.params, batch)
974
+ eval_preds.extend(
975
+ jax.device_get(
976
+ generated_ids.reshape(-1, gen_kwargs["max_length"])
977
+ )
978
+ )
979
+ eval_labels.extend(
980
+ jax.device_get(labels.reshape(-1, labels.shape[-1]))
981
+ )
982
+
983
+ # normalize eval metrics
984
+ eval_metrics = get_metrics(eval_metrics)
985
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
986
+
987
+ # compute several metrics
988
+ mix_desc = ""
989
  if data_args.predict_with_generate:
990
+ mix_metrics = compute_metrics(eval_preds, eval_labels)
991
+ eval_metrics.update(mix_metrics)
992
+ mix_desc = " ".join(
993
+ [f"Eval {key}: {value} |" for key, value in mix_metrics.items()]
994
+ )
995
+
996
+ # Print metrics and update progress bar
997
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {mix_desc} )"
998
+ epochs.write(desc)
999
+ epochs.desc = desc
1000
+
1001
+ # Save metrics
1002
+ if has_tensorboard and jax.process_index() == 0:
1003
+ logger.info(
1004
+ f"*** Writing evaluation summary after {cur_step} steps ***"
1005
+ )
1006
+ # cur_step = epoch * (len(train_dataset) // train_batch_size)
1007
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
1008
+
1009
+ # ======================== Prediction loop ==============================
1010
+ if training_args.do_predict:
1011
+ logger.info("*** Predict ***")
1012
+
1013
+ pred_metrics = []
1014
+ pred_generations = []
1015
+ pred_labels = []
1016
+
1017
+ pred_loader = data_loader(
1018
+ input_rng, predict_dataset, eval_batch_size
1019
+ )
1020
+ pred_steps = len(predict_dataset) // eval_batch_size
1021
+ for _ in tqdm(
1022
+ range(pred_steps), desc="Predicting...", position=2, leave=False
1023
+ ):
1024
+ # Model forward
1025
+ batch = next(pred_loader)
1026
+ labels = batch["labels"]
1027
+
1028
+ metrics = p_eval_step(state.params, batch)
1029
+ pred_metrics.append(metrics)
1030
+
1031
+ # generation
1032
+ if data_args.predict_with_generate:
1033
+ generated_ids = p_generate_step(state.params, batch)
1034
+ pred_generations.extend(
1035
+ jax.device_get(
1036
+ generated_ids.reshape(-1, gen_kwargs["max_length"])
1037
+ )
1038
+ )
1039
+ pred_labels.extend(
1040
+ jax.device_get(labels.reshape(-1, labels.shape[-1]))
1041
+ )
1042
+
1043
+ # normalize prediction metrics
1044
+ pred_metrics = get_metrics(pred_metrics)
1045
+ pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
1046
+
1047
+ # compute ROUGE metrics
1048
+ rouge_desc = ""
1049
+ if data_args.predict_with_generate:
1050
+ rouge_metrics = compute_metrics(pred_generations, pred_labels)
1051
+ pred_metrics.update(rouge_metrics)
1052
+ rouge_desc = " ".join(
1053
+ [
1054
+ f"Predict {key}: {value} |"
1055
+ for key, value in rouge_metrics.items()
1056
+ ]
1057
+ )
1058
+
1059
+ # Print metrics
1060
+ desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
1061
+ logger.info(desc)
1062
+
1063
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
1064
+ logger.info(f"*** Saving checkpoints after {cur_step} steps ***")
1065
+ # save checkpoint after each steps and push checkpoint to the hub
1066
+ if jax.process_index() == 0:
1067
+ save_checkpoint(model, training_args.output_dir, state)
1068
 
1069
  # save checkpoint after each epoch and push checkpoint to the hub
1070
  if jax.process_index() == 0:
 
1075
  push_to_hub=training_args.push_to_hub,
1076
  commit_message=f"Saving weights and logs of epoch {epoch+1}",
1077
  )
 
 
1078
 
1079
 
1080
  if __name__ == "__main__":