seduerr committed
Commit 2aafb7c
Parent: 2e76937

Update from seduerr

Files changed (40)
  1. .DS_Store +0 -0
  2. README.md +0 -55
  3. baseline.png +0 -0
  4. convert_to_pytorch.py +0 -3
  5. convert_to_tensorflow.py +0 -3
  6. events.out.tfevents.1625592008.t1v-n-6586652e-w-0.376816.3.v2 +0 -0
  7. events.out.tfevents.1625592046.t1v-n-6586652e-w-0.378250.3.v2 +0 -0
  8. events.out.tfevents.1625592624.t1v-n-6586652e-w-0.380500.3.v2 +0 -0
  9. events.out.tfevents.1625593313.t1v-n-6586652e-w-0.382633.3.v2 +0 -0
  10. events.out.tfevents.1625626842.t1v-n-6586652e-w-0.33177.3.v2 +0 -0
  11. events.out.tfevents.1625627167.t1v-n-6586652e-w-0.34916.3.v2 +0 -0
  12. events.out.tfevents.1625629237.t1v-n-6586652e-w-0.38427.3.v2 +0 -0
  13. events.out.tfevents.1625629547.t1v-n-6586652e-w-0.40115.3.v2 +0 -0
  14. events.out.tfevents.1625629724.t1v-n-6586652e-w-0.41718.3.v2 +0 -0
  15. events.out.tfevents.1625629983.t1v-n-6586652e-w-0.43435.3.v2 +0 -0
  16. events.out.tfevents.1625630238.t1v-n-6586652e-w-0.45070.3.v2 +0 -0
  17. events.out.tfevents.1625630522.t1v-n-6586652e-w-0.46805.3.v2 +0 -0
  18. events.out.tfevents.1625630799.t1v-n-6586652e-w-0.48498.3.v2 +0 -0
  19. events.out.tfevents.1625631065.t1v-n-6586652e-w-0.50228.3.v2 +0 -0
  20. events.out.tfevents.1625631515.t1v-n-6586652e-w-0.52068.3.v2 +0 -0
  21. events.out.tfevents.1625631789.t1v-n-6586652e-w-0.53785.3.v2 +0 -0
  22. events.out.tfevents.1625633047.t1v-n-6586652e-w-0.63010.3.v2 +0 -0
  23. events.out.tfevents.1625633391.t1v-n-6586652e-w-0.64789.3.v2 +0 -0
  24. events.out.tfevents.1625633623.t1v-n-6586652e-w-0.66428.3.v2 +0 -0
  25. events.out.tfevents.1625634485.t1v-n-6586652e-w-0.68641.3.v2 +0 -0
  26. events.out.tfevents.1625640376.t1v-n-6586652e-w-0.77306.3.v2 +0 -0
  27. events.out.tfevents.1625640716.t1v-n-6586652e-w-0.79335.3.v2 +0 -0
  28. events.out.tfevents.1625644153.t1v-n-6586652e-w-0.84715.3.v2 +0 -0
  29. events.out.tfevents.1625644417.t1v-n-6586652e-w-0.86372.3.v2 +0 -0
  30. events.out.tfevents.1625644841.t1v-n-6586652e-w-0.88374.3.v2 +0 -0
  31. events.out.tfevents.1625652547.t1v-n-6586652e-w-0.96957.3.v2 +0 -0
  32. events.out.tfevents.1625652631.t1v-n-6586652e-w-0.98303.3.v2 +0 -0
  33. events.out.tfevents.1625652826.t1v-n-6586652e-w-0.99923.3.v2 +0 -0
  34. events.out.tfevents.1625653423.t1v-n-6586652e-w-0.102079.3.v2 +0 -0
  35. events.out.tfevents.1625653795.t1v-n-6586652e-w-0.103992.3.v2 +0 -0
  36. events.out.tfevents.1625653985.t1v-n-6586652e-w-0.105597.3.v2 +0 -0
  37. events.out.tfevents.1625673372.t1v-n-6586652e-w-0.130469.3.v2 +0 -0
  38. exact_match/exact_match.py +0 -47
  39. exact_match/exact_match.py.lock +0 -0
  40. run_summarization_flax.py +0 -823
.DS_Store ADDED
Binary file (6.15 kB)
README.md DELETED
@@ -1,55 +0,0 @@
1
- ---
2
- datasets:
3
- - wiki_split
4
-
5
- widget:
6
- - text: "Mary likes to play football in her free time whenever she meets with her friends that are very nice people."
7
-
8
- license: mit
9
- ---
10
- # T5 model for sentence splitting in English
11
-
12
- Sentence splitting is the task of dividing a long sentence into multiple shorter sentences.
13
- E.g.:
14
- ```
15
- Mary likes to play football in her free time whenever she meets with her friends that are very nice people.
16
- ```
17
- could be split into
18
- ```
19
- Mary likes to play football in her free time whenever she meets with her friends.
20
- ```
21
- ```
22
- Her friends are very nice people.
23
- ```
24
-
25
- ## How to use it in your code:
26
- ```python
27
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
28
- tokenizer = AutoTokenizer.from_pretrained("flax-community/t5-base-wikisplit")
29
- model = AutoModelForSeq2SeqLM.from_pretrained("flax-community/t5-base-wikisplit")
30
-
31
- complex_sentence = "This comedy drama is produced by Tidy , the company she co-founded in 2008 with her husband David Peet , who is managing director ."
32
- sample_tokenized = tokenizer(complex_sentence, return_tensors="pt")
33
-
34
- answer = model.generate(sample_tokenized['input_ids'], attention_mask = sample_tokenized['attention_mask'], max_length=256, num_beams=5)
35
- gene_sentence = tokenizer.decode(answer[0], skip_special_tokens=True)
36
- gene_sentence
37
-
38
- """
39
- Output:
40
- This comedy drama is produced by Tidy. She co-founded Tidy in 2008 with her husband David Peet, who is managing director.
41
- """
42
- ```
43
- ## Datasets:
44
- [Wiki_Split](https://research.google/tools/datasets/wiki-split/)
45
-
46
- ## Current Baseline from [paper](https://arxiv.org/abs/1907.12461)
47
- ![baseline](./baseline.png)
48
-
49
- ## Our Results on Predict/Test set:
50
- | Model | Exact | SARI | BLEU |
51
- | --- | --- | --- | --- |
52
- | [t5-base-wikisplit](https://huggingface.co/flax-community/t5-base-wikisplit) | 17.93 | 67.5438 | 76.9 |
53
- | [t5-v1_1-base-wikisplit](https://huggingface.co/flax-community/t5-v1_1-base-wikisplit) | 18.1207 | 67.4873 | 76.9478 |
54
- | [byt5-base-wikisplit](https://huggingface.co/flax-community/byt5-base-wikisplit) | 11.3582 | 67.2685 | 73.1682 |
55
- | [t5-large-wikisplit](https://huggingface.co/flax-community/t5-large-wikisplit) | 18.6632 | 68.0501 | 77.1881 |
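The Exact, SARI, and BLEU columns above are produced by the metric calls in the deleted `run_summarization_flax.py` further down in this commit (`sacrebleu`, `sari`, and the local `exact_match` script). A minimal standalone sketch of that scoring, using an illustrative sentence pair rather than the real wiki_split test set, might look like this:

```python
from datasets import load_metric  # the "sacrebleu" metric also needs the sacrebleu package installed

# Illustrative example only; the numbers in the table come from scoring the full
# wiki_split test set the same way run_summarization_flax.py below does.
sources = ["Mary likes to play football in her free time whenever she meets with her friends that are very nice people."]
predictions = ["Mary likes to play football in her free time whenever she meets with her friends. Her friends are very nice people."]
references = [["Mary likes to play football in her free time whenever she meets with her friends. Her friends are very nice people."]]

sacrebleu = load_metric("sacrebleu")
sari = load_metric("sari")

bleu = sacrebleu.compute(predictions=predictions, references=references)["score"]
sari_score = sari.compute(sources=sources, predictions=predictions, references=references)["sari"]
print(f"BLEU: {bleu:.2f} | SARI: {sari_score:.2f}")
```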
baseline.png DELETED
Binary file (49.9 kB)
convert_to_pytorch.py DELETED
@@ -1,3 +0,0 @@
1
- from transformers import AutoModelForSeq2SeqLM
2
- model = AutoModelForSeq2SeqLM.from_pretrained("./", from_flax=True)
3
- model.save_pretrained("./")
convert_to_tensorflow.py DELETED
@@ -1,3 +0,0 @@
1
- from transformers import TFAutoModelForSeq2SeqLM
2
- model = TFAutoModelForSeq2SeqLM.from_pretrained("./", from_pt=True)
3
- model.save_pretrained("./")
events.out.tfevents.1625592008.t1v-n-6586652e-w-0.376816.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625592046.t1v-n-6586652e-w-0.378250.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625592624.t1v-n-6586652e-w-0.380500.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625593313.t1v-n-6586652e-w-0.382633.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625626842.t1v-n-6586652e-w-0.33177.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625627167.t1v-n-6586652e-w-0.34916.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625629237.t1v-n-6586652e-w-0.38427.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625629547.t1v-n-6586652e-w-0.40115.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625629724.t1v-n-6586652e-w-0.41718.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625629983.t1v-n-6586652e-w-0.43435.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625630238.t1v-n-6586652e-w-0.45070.3.v2 DELETED
Binary file (862 Bytes)
events.out.tfevents.1625630522.t1v-n-6586652e-w-0.46805.3.v2 DELETED
Binary file (994 Bytes)
events.out.tfevents.1625630799.t1v-n-6586652e-w-0.48498.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625631065.t1v-n-6586652e-w-0.50228.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625631515.t1v-n-6586652e-w-0.52068.3.v2 DELETED
Binary file (1.13 kB)
events.out.tfevents.1625631789.t1v-n-6586652e-w-0.53785.3.v2 DELETED
Binary file (1.13 kB)
events.out.tfevents.1625633047.t1v-n-6586652e-w-0.63010.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625633391.t1v-n-6586652e-w-0.64789.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625633623.t1v-n-6586652e-w-0.66428.3.v2 DELETED
Binary file (1.13 kB)
events.out.tfevents.1625634485.t1v-n-6586652e-w-0.68641.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625640376.t1v-n-6586652e-w-0.77306.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625640716.t1v-n-6586652e-w-0.79335.3.v2 DELETED
Binary file (630 kB)
events.out.tfevents.1625644153.t1v-n-6586652e-w-0.84715.3.v2 DELETED
Binary file (1.13 kB)
events.out.tfevents.1625644417.t1v-n-6586652e-w-0.86372.3.v2 DELETED
Binary file (1.13 kB)
events.out.tfevents.1625644841.t1v-n-6586652e-w-0.88374.3.v2 DELETED
Binary file (1.13 kB)
events.out.tfevents.1625652547.t1v-n-6586652e-w-0.96957.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625652631.t1v-n-6586652e-w-0.98303.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625652826.t1v-n-6586652e-w-0.99923.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625653423.t1v-n-6586652e-w-0.102079.3.v2 DELETED
Binary file (902 Bytes)
events.out.tfevents.1625653795.t1v-n-6586652e-w-0.103992.3.v2 DELETED
Binary file (40 Bytes)
events.out.tfevents.1625653985.t1v-n-6586652e-w-0.105597.3.v2 DELETED
Binary file (2.34 MB)
events.out.tfevents.1625673372.t1v-n-6586652e-w-0.130469.3.v2 DELETED
Binary file (2.34 MB)
exact_match/exact_match.py DELETED
@@ -1,47 +0,0 @@
1
- import datasets
2
- import re
3
- import string
4
-
5
- def normalize_answer(s):
6
- """Lower text and remove punctuation, articles and extra whitespace."""
7
-
8
- def remove_articles(text):
9
- regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
10
- return re.sub(regex, " ", text)
11
-
12
- def white_space_fix(text):
13
- return " ".join(text.split())
14
-
15
- def remove_punc(text):
16
- exclude = set(string.punctuation)
17
- return "".join(ch for ch in text if ch not in exclude)
18
-
19
- def lower(text):
20
- return text.lower()
21
-
22
- return white_space_fix(remove_articles(remove_punc(lower(s))))
23
-
24
- def compute_exact(a_gold, a_pred):
25
- return int(normalize_answer(a_gold) == normalize_answer(a_pred))
26
-
27
- def compute_em(predictions, references):
28
- scores = [compute_exact(ref, pred) for pred, ref in zip(predictions, references)]
29
- return sum(scores)/len(scores)
30
-
31
- class ExactMatch(datasets.Metric):
32
- def _info(self):
33
- return datasets.MetricInfo(
34
- description="Computes the exact-match score between predictions and references after text normalization.",
35
- citation="",
36
- homepage="",
37
- inputs_description="",
38
- features=datasets.Features({
39
- 'predictions': datasets.Value('string'),
40
- 'references': datasets.Value('string'),
41
- }),
42
- codebase_urls=["https://github.com/huggingface/transformers/blob/master/src/transformers/data/metrics/squad_metrics.py"],
43
- reference_urls=["https://github.com/huggingface/transformers/blob/master/src/transformers/data/metrics/squad_metrics.py"]
44
- )
45
-
46
- def _compute(self, predictions, references):
47
- return {"exact_match": compute_em(predictions, references)}
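For reference, the deleted `run_summarization_flax.py` below loads this script with `datasets.load_metric` and calls `compute(predictions=..., references=...)`. A minimal usage sketch, assuming the `exact_match/` folder from this repo is available locally (the relative path is only an illustration):

```python
from datasets import load_metric

# Point load_metric at the directory that holds exact_match/exact_match.py.
exact_match = load_metric("./exact_match")

predictions = ["Her friends are very nice people."]
references = ["her friends are very nice people"]

# normalize_answer lowercases and strips punctuation, articles and extra
# whitespace, so this pair still counts as an exact match.
print(exact_match.compute(predictions=predictions, references=references))
# -> {'exact_match': 1.0}
```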
exact_match/exact_match.py.lock DELETED
File without changes
run_summarization_flax.py DELETED
@@ -1,823 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2021 The HuggingFace Team All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
- Fine-tuning the library models for summarization.
18
- """
19
- # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
-
21
- import logging
22
- import os
23
- import sys
24
- import time
25
- from dataclasses import dataclass, field
26
- from functools import partial
27
- from pathlib import Path
28
- from typing import Callable, Optional
29
-
30
- import datasets
31
- import nltk # Here to have a nice missing dependency error message early on
32
- import numpy as np
33
- from datasets import Dataset, load_dataset, load_metric
34
- from tqdm import tqdm
35
-
36
- import jax
37
- import jax.numpy as jnp
38
- import optax
39
- import transformers
40
- from filelock import FileLock
41
- from flax import jax_utils, traverse_util
42
- from flax.jax_utils import unreplicate
43
- from flax.training import train_state
44
- from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
45
- from transformers import (
46
- CONFIG_MAPPING,
47
- FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
48
- AutoConfig,
49
- AutoTokenizer,
50
- FlaxAutoModelForSeq2SeqLM,
51
- HfArgumentParser,
52
- TrainingArguments,
53
- is_tensorboard_available,
54
- )
55
- from transformers.file_utils import is_offline_mode
56
-
57
-
58
- logger = logging.getLogger(__name__)
59
-
60
- try:
61
- nltk.data.find("tokenizers/punkt")
62
- except (LookupError, OSError):
63
- if is_offline_mode():
64
- raise LookupError(
65
- "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
66
- )
67
- with FileLock(".lock") as lock:
68
- nltk.download("punkt", quiet=True)
69
-
70
-
71
- MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
72
- MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
73
-
74
-
75
- @dataclass
76
- class ModelArguments:
77
- """
78
- Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
79
- """
80
-
81
- model_name_or_path: Optional[str] = field(
82
- default=None,
83
- metadata={
84
- "help": "The model checkpoint for weights initialization."
85
- "Don't set if you want to train a model from scratch."
86
- },
87
- )
88
- model_type: Optional[str] = field(
89
- default=None,
90
- metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
91
- )
92
- config_name: Optional[str] = field(
93
- default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
94
- )
95
- tokenizer_name: Optional[str] = field(
96
- default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
97
- )
98
- cache_dir: Optional[str] = field(
99
- default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
100
- )
101
- use_fast_tokenizer: bool = field(
102
- default=True,
103
- metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
104
- )
105
- dtype: Optional[str] = field(
106
- default="float32",
107
- metadata={
108
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
109
- },
110
- )
111
-
112
-
113
- @dataclass
114
- class DataTrainingArguments:
115
- """
116
- Arguments pertaining to what data we are going to input our model for training and eval.
117
- """
118
-
119
- dataset_name: Optional[str] = field(
120
- default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
121
- )
122
- dataset_config_name: Optional[str] = field(
123
- default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
124
- )
125
- text_column: Optional[str] = field(
126
- default=None,
127
- metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
128
- )
129
- summary_column: Optional[str] = field(
130
- default=None,
131
- metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
132
- )
133
- train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
134
- validation_file: Optional[str] = field(
135
- default=None,
136
- metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
137
- )
138
- test_file: Optional[str] = field(
139
- default=None,
140
- metadata={"help": "An optional input prediction data file to evaluate the perplexity on (a text file)."},
141
- )
142
- max_source_length: Optional[int] = field(
143
- default=1024,
144
- metadata={
145
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
146
- "than this will be truncated, sequences shorter will be padded."
147
- },
148
- )
149
- max_target_length: Optional[int] = field(
150
- default=128,
151
- metadata={
152
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
153
- "than this will be truncated, sequences shorter will be padded."
154
- },
155
- )
156
- val_max_target_length: Optional[int] = field(
157
- default=None,
158
- metadata={
159
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
160
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
161
- "This argument is also used to override the `max_length` param of `model.generate`, which is used "
162
- "during evaluation."
163
- },
164
- )
165
- max_train_samples: Optional[int] = field(
166
- default=None,
167
- metadata={
168
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
169
- "value if set."
170
- },
171
- )
172
- max_eval_samples: Optional[int] = field(
173
- default=None,
174
- metadata={
175
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
176
- "value if set."
177
- },
178
- )
179
- max_predict_samples: Optional[int] = field(
180
- default=None,
181
- metadata={
182
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
183
- "value if set."
184
- },
185
- )
186
- preprocessing_num_workers: Optional[int] = field(
187
- default=None,
188
- metadata={"help": "The number of processes to use for the preprocessing."},
189
- )
190
- source_prefix: Optional[str] = field(
191
- default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
192
- )
193
- predict_with_generate: bool = field(
194
- default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
195
- )
196
- num_beams: Optional[int] = field(
197
- default=None,
198
- metadata={
199
- "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
200
- "which is used during evaluation."
201
- },
202
- )
203
- overwrite_cache: bool = field(
204
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
205
- )
206
-
207
- def __post_init__(self):
208
- if self.dataset_name is None and self.train_file is None and self.validation_file is None:
209
- raise ValueError("Need either a dataset name or a training/validation file.")
210
- else:
211
- if self.train_file is not None:
212
- extension = self.train_file.split(".")[-1]
213
- assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
214
- if self.validation_file is not None:
215
- extension = self.validation_file.split(".")[-1]
216
- assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
217
- if self.val_max_target_length is None:
218
- self.val_max_target_length = self.max_target_length
219
-
220
-
221
- summarization_name_mapping = {
222
- "amazon_reviews_multi": ("review_body", "review_title"),
223
- "big_patent": ("description", "abstract"),
224
- "cnn_dailymail": ("article", "highlights"),
225
- "orange_sum": ("text", "summary"),
226
- "pn_summary": ("article", "summary"),
227
- "psc": ("extract_text", "summary_text"),
228
- "samsum": ("dialogue", "summary"),
229
- "thaisum": ("body", "summary"),
230
- "xglue": ("news_body", "news_title"),
231
- "xsum": ("document", "summary"),
232
- "wiki_summary": ("article", "highlights"),
233
- }
234
-
235
-
236
- class TrainState(train_state.TrainState):
237
- dropout_rng: jnp.ndarray
238
-
239
- def replicate(self):
240
- return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
241
-
242
-
243
- def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
244
- """
245
- Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
246
- Shuffle batches if `shuffle` is `True`.
247
- """
248
- steps_per_epoch = len(dataset) // batch_size
249
-
250
- if shuffle:
251
- batch_idx = jax.random.permutation(rng, len(dataset))
252
- else:
253
- batch_idx = jnp.arange(len(dataset))
254
-
255
- batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
256
- batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
257
-
258
- for idx in batch_idx:
259
- batch = dataset[idx]
260
- batch = {k: jnp.array(v) for k, v in batch.items()}
261
-
262
- batch = shard(batch)
263
-
264
- yield batch
265
-
266
-
267
- def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
268
- summary_writer.scalar("train_time", train_time, step)
269
-
270
- train_metrics = get_metrics(train_metrics)
271
- for key, vals in train_metrics.items():
272
- tag = f"train_{key}"
273
- for i, val in enumerate(vals):
274
- summary_writer.scalar(tag, val, step - len(vals) + i + 1)
275
-
276
- for metric_name, value in eval_metrics.items():
277
- summary_writer.scalar(f"eval_{metric_name}", value, step)
278
-
279
-
280
- def create_learning_rate_fn(
281
- train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
282
- ) -> Callable[[int], jnp.array]:
283
- """Returns a linear warmup, linear_decay learning rate function."""
284
- steps_per_epoch = train_ds_size // train_batch_size
285
- num_train_steps = steps_per_epoch * num_train_epochs
286
- warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
287
- decay_fn = optax.linear_schedule(
288
- init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
289
- )
290
- schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
291
- return schedule_fn
292
-
293
-
294
- def main():
295
- # See all possible arguments in src/transformers/training_args.py
296
- # or by passing the --help flag to this script.
297
- # We now keep distinct sets of args, for a cleaner separation of concerns.
298
-
299
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
300
- if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
301
- # If we pass only one argument to the script and it's the path to a json file,
302
- # let's parse it to get our arguments.
303
- model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
304
- else:
305
- model_args, data_args, training_args = parser.parse_args_into_dataclasses()
306
-
307
- if (
308
- os.path.exists(training_args.output_dir)
309
- and os.listdir(training_args.output_dir)
310
- and training_args.do_train
311
- and not training_args.overwrite_output_dir
312
- ):
313
- raise ValueError(
314
- f"Output directory ({training_args.output_dir}) already exists and is not empty."
315
- "Use --overwrite_output_dir to overcome."
316
- )
317
-
318
- # Make one log on every process with the configuration for debugging.
319
- logging.basicConfig(
320
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
321
- datefmt="%m/%d/%Y %H:%M:%S",
322
- level=logging.INFO,
323
- )
324
- # Setup logging, we only want one process per machine to log things on the screen.
325
- logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
326
- if jax.process_index() == 0:
327
- datasets.utils.logging.set_verbosity_warning()
328
- transformers.utils.logging.set_verbosity_info()
329
- else:
330
- datasets.utils.logging.set_verbosity_error()
331
- transformers.utils.logging.set_verbosity_error()
332
-
333
- # Set the verbosity to info of the Transformers logger (on main process only):
334
- logger.info(f"Training/evaluation parameters {training_args}")
335
-
336
- # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
337
- # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
338
- # (the dataset will be downloaded automatically from the datasets Hub).
339
- #
340
- # For CSV/JSON files this script will use the first column for the full texts and the second column for the
341
- # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
342
- #
343
- if data_args.dataset_name is not None:
344
- # Downloading and loading a dataset from the hub.
345
- dataset = load_dataset(
346
- data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
347
- )
348
- else:
349
- data_files = {}
350
- if data_args.train_file is not None:
351
- data_files["train"] = data_args.train_file
352
- extension = data_args.train_file.split(".")[-1]
353
- if data_args.validation_file is not None:
354
- data_files["validation"] = data_args.validation_file
355
- extension = data_args.validation_file.split(".")[-1]
356
- if data_args.test_file is not None:
357
- data_files["test"] = data_args.test_file
358
- extension = data_args.test_file.split(".")[-1]
359
- dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
360
- # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
361
- # https://huggingface.co/docs/datasets/loading_datasets.html.
362
-
363
- # Load pretrained model and tokenizer
364
-
365
- if model_args.config_name:
366
- config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
367
- elif model_args.model_name_or_path:
368
- config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
369
- else:
370
- config = CONFIG_MAPPING[model_args.model_type]()
371
- logger.warning("You are instantiating a new config instance from scratch.")
372
-
373
- if model_args.tokenizer_name:
374
- tokenizer = AutoTokenizer.from_pretrained(
375
- model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
376
- )
377
- elif model_args.model_name_or_path:
378
- tokenizer = AutoTokenizer.from_pretrained(
379
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
380
- )
381
- else:
382
- raise ValueError(
383
- "You are instantiating a new tokenizer from scratch. This is not supported by this script."
384
- "You can do it from another script, save it, and load it from here, using --tokenizer_name."
385
- )
386
-
387
- if model_args.model_name_or_path:
388
- model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
389
- model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
390
- )
391
- else:
392
- model = FlaxAutoModelForSeq2SeqLM.from_config(
393
- config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
394
- )
395
-
396
- if model.config.decoder_start_token_id is None:
397
- raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
398
-
399
- prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
400
-
401
- # Preprocessing the datasets.
402
- # We need to tokenize inputs and targets.
403
- if training_args.do_train:
404
- column_names = dataset["train"].column_names
405
- elif training_args.do_eval:
406
- column_names = dataset["validation"].column_names
407
- elif training_args.do_predict:
408
- column_names = dataset["test"].column_names
409
- else:
410
- logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
411
- return
412
-
413
- # Get the column names for input/target.
414
- dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
415
- if data_args.text_column is None:
416
- text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
417
- else:
418
- text_column = data_args.text_column
419
- if text_column not in column_names:
420
- raise ValueError(
421
- f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
422
- )
423
- if data_args.summary_column is None:
424
- summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
425
- else:
426
- summary_column = data_args.summary_column
427
- if summary_column not in column_names:
428
- raise ValueError(
429
- f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
430
- )
431
-
432
- # Temporarily set max_target_length for training.
433
- max_target_length = data_args.max_target_length
434
-
435
- # In Flax, for seq2seq models we need to pass `decoder_input_ids`
436
- # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
437
- # for that dynamically import the `shift_tokens_right` function from the model file
438
- model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
439
- shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
440
-
441
- # Setting padding="max_length" as we need fixed length inputs for jitted functions
442
- def preprocess_function(examples):
443
- inputs = examples[text_column]
444
- targets = examples[summary_column]
445
- inputs = [prefix + inp for inp in inputs]
446
- model_inputs = tokenizer(
447
- inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
448
- )
449
-
450
- # Setup the tokenizer for targets
451
- with tokenizer.as_target_tokenizer():
452
- labels = tokenizer(
453
- targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
454
- )
455
-
456
- model_inputs["labels"] = labels["input_ids"]
457
- decoder_input_ids = shift_tokens_right_fn(
458
- jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id
459
- )
460
- model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
461
-
462
- # We need decoder_attention_mask so we can ignore pad tokens from loss
463
- model_inputs["decoder_attention_mask"] = labels["attention_mask"]
464
-
465
- return model_inputs
466
-
467
- if training_args.do_train:
468
- if "train" not in dataset:
469
- raise ValueError("--do_train requires a train dataset")
470
- train_dataset = dataset["train"]
471
- if data_args.max_train_samples is not None:
472
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
473
- train_dataset = train_dataset.map(
474
- preprocess_function,
475
- batched=True,
476
- num_proc=data_args.preprocessing_num_workers,
477
- remove_columns=column_names,
478
- load_from_cache_file=not data_args.overwrite_cache,
479
- desc="Running tokenizer on train dataset",
480
- )
481
-
482
- if training_args.do_eval:
483
- max_target_length = data_args.val_max_target_length
484
- if "validation" not in dataset:
485
- raise ValueError("--do_eval requires a validation dataset")
486
- eval_dataset = dataset["validation"]
487
- if data_args.max_eval_samples is not None:
488
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
489
- eval_dataset = eval_dataset.map(
490
- preprocess_function,
491
- batched=True,
492
- num_proc=data_args.preprocessing_num_workers,
493
- remove_columns=column_names,
494
- load_from_cache_file=not data_args.overwrite_cache,
495
- desc="Running tokenizer on validation dataset",
496
- )
497
-
498
- if training_args.do_predict:
499
- max_target_length = data_args.val_max_target_length
500
- if "test" not in dataset:
501
- raise ValueError("--do_predict requires a test dataset")
502
- predict_dataset = dataset["test"]
503
- if data_args.max_predict_samples is not None:
504
- predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
505
- predict_dataset = predict_dataset.map(
506
- preprocess_function,
507
- batched=True,
508
- num_proc=data_args.preprocessing_num_workers,
509
- remove_columns=column_names,
510
- load_from_cache_file=not data_args.overwrite_cache,
511
- desc="Running tokenizer on prediction dataset",
512
- )
513
-
514
- # Metric
515
- sacrebleu = load_metric("sacrebleu")
516
- sari = load_metric("sari")
517
- em = load_metric("/home/bhadresh/transformers/examples/flax/summarization/exact_match")
518
-
519
- def postprocess_text(preds, labels, sources):
520
- preds = [pred.strip() for pred in preds]
521
- sources = [source.strip() for source in sources]
522
- labels = [[label.strip()] for label in labels]
523
- pure_labels = [label[0] for label in labels]
524
- return preds, labels, pure_labels, sources
525
-
526
- def compute_metrics(sources, preds, labels):
527
- decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
528
- decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
529
- decoded_src = tokenizer.batch_decode(sources, skip_special_tokens=True)
530
-
531
- # Some simple post-processing
532
- decoded_preds, decoded_labels, pure_decoded_labels, decoded_src = postprocess_text(decoded_preds, decoded_labels, decoded_src)
533
- print(len(decoded_preds))
534
- print(len(decoded_labels))
535
- print(len(pure_decoded_labels))
536
- print(len(decoded_src))
537
- sacrebleu_result = sacrebleu.compute(predictions=decoded_preds, references=decoded_labels)
538
- sari_result = sari.compute(sources=decoded_src, predictions=decoded_preds, references=decoded_labels)
539
- exact_result = em.compute(predictions=decoded_preds, references=pure_decoded_labels)
540
-
541
- result = {
542
- "bleu": sacrebleu_result["score"],
543
- "sari": sari_result['sari'],
544
- "exact": exact_result['exact_match']
545
- }
546
-
547
- prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
548
- result["gen_len"] = np.mean(prediction_lens)
549
- result = {k: round(v, 4) for k, v in result.items()}
550
- return result
551
-
552
- # Enable tensorboard only on the master node
553
- has_tensorboard = is_tensorboard_available()
554
- if has_tensorboard and jax.process_index() == 0:
555
- try:
556
- from flax.metrics.tensorboard import SummaryWriter
557
-
558
- summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
559
- except ImportError as ie:
560
- has_tensorboard = False
561
- logger.warning(
562
- f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
563
- )
564
- else:
565
- logger.warning(
566
- "Unable to display metrics through TensorBoard because the package is not installed: "
567
- "Please run pip install tensorboard to enable."
568
- )
569
-
570
- # Initialize our training
571
- rng = jax.random.PRNGKey(training_args.seed)
572
- rng, dropout_rng = jax.random.split(rng)
573
-
574
- # Store some constant
575
- num_epochs = int(training_args.num_train_epochs)
576
- train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
577
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
578
- steps_per_epoch = len(train_dataset) // train_batch_size
579
- total_train_steps = steps_per_epoch * num_epochs
580
-
581
- # Create learning rate schedule
582
- linear_decay_lr_schedule_fn = create_learning_rate_fn(
583
- len(train_dataset),
584
- train_batch_size,
585
- training_args.num_train_epochs,
586
- training_args.warmup_steps,
587
- training_args.learning_rate,
588
- )
589
-
590
- # We use Optax's "masking" functionality to not apply weight decay
591
- # to bias and LayerNorm scale parameters. decay_mask_fn returns a
592
- # mask boolean with the same structure as the parameters.
593
- # The mask is True for parameters that should be decayed.
594
- # Note that this mask is specifically adapted for FlaxBart.
595
- # For FlaxT5, one should correct the layer norm parameter naming
596
- # accordingly - see `run_t5_mlm_flax.py` e.g.
597
- def decay_mask_fn(params):
598
- flat_params = traverse_util.flatten_dict(params)
599
- layer_norm_params = [
600
- (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
601
- ]
602
- flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
603
- return traverse_util.unflatten_dict(flat_mask)
604
-
605
- # create adam optimizer
606
- adamw = optax.adamw(
607
- learning_rate=linear_decay_lr_schedule_fn,
608
- b1=training_args.adam_beta1,
609
- b2=training_args.adam_beta2,
610
- eps=training_args.adam_epsilon,
611
- weight_decay=training_args.weight_decay,
612
- mask=decay_mask_fn,
613
- )
614
-
615
- # Setup train state
616
- state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
617
-
618
- # label smoothed cross entropy
619
- def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
620
- """
621
- The label smoothing implementation is adapted from Flax's official example:
622
- https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
623
- """
624
- vocab_size = logits.shape[-1]
625
- confidence = 1.0 - label_smoothing_factor
626
- low_confidence = (1.0 - confidence) / (vocab_size - 1)
627
- normalizing_constant = -(
628
- confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
629
- )
630
- soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
631
-
632
- loss = optax.softmax_cross_entropy(logits, soft_labels)
633
- loss = loss - normalizing_constant
634
-
635
- # ignore padded tokens from loss
636
- loss = loss * padding_mask
637
- loss = loss.sum() / padding_mask.sum()
638
- return loss
639
-
640
- # Define gradient update step fn
641
- def train_step(state, batch, label_smoothing_factor=0.0):
642
- dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
643
-
644
- def compute_loss(params):
645
- labels = batch.pop("labels")
646
- logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
647
- loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
648
- return loss
649
-
650
- grad_fn = jax.value_and_grad(compute_loss)
651
- loss, grad = grad_fn(state.params)
652
- grad = jax.lax.pmean(grad, "batch")
653
-
654
- new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
655
-
656
- metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
657
- metrics = jax.lax.pmean(metrics, axis_name="batch")
658
-
659
- return new_state, metrics
660
-
661
- # Define eval fn
662
- def eval_step(params, batch, label_smoothing_factor=0.0):
663
- labels = batch.pop("labels")
664
- logits = model(**batch, params=params, train=False)[0]
665
- loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
666
-
667
- # summarize metrics
668
- metrics = {"loss": loss}
669
- metrics = jax.lax.pmean(metrics, axis_name="batch")
670
- return metrics
671
-
672
- # Define generation function
673
- max_length = (
674
- data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
675
- )
676
- num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
677
- gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
678
-
679
- def generate_step(params, batch):
680
- model.params = params
681
- output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
682
- return output_ids.sequences
683
-
684
- # Create parallel version of the train and eval step
685
- p_train_step = jax.pmap(
686
- partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
687
- )
688
- p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
689
- p_generate_step = jax.pmap(generate_step, "batch")
690
-
691
- # Replicate the train state on each device
692
- state = state.replicate()
693
-
694
- logger.info("***** Running training *****")
695
- logger.info(f" Num examples = {len(train_dataset)}")
696
- logger.info(f" Num Epochs = {num_epochs}")
697
- logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
698
- logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
699
- logger.info(f" Total optimization steps = {total_train_steps}")
700
-
701
- train_time = 0
702
- epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
703
- for epoch in epochs:
704
- # ======================== Training ================================
705
- train_start = time.time()
706
-
707
- # Create sampling rng
708
- rng, input_rng = jax.random.split(rng)
709
- train_metrics = []
710
-
711
- # Generate an epoch by shuffling sampling indices from the train dataset
712
- train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
713
- steps_per_epoch = len(train_dataset) // train_batch_size
714
- # train
715
- for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
716
- batch = next(train_loader)
717
- state, train_metric = p_train_step(state, batch)
718
- train_metrics.append(train_metric)
719
-
720
- train_time += time.time() - train_start
721
-
722
- train_metric = unreplicate(train_metric)
723
-
724
- epochs.write(
725
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
726
- )
727
-
728
- # ======================== Evaluating ==============================
729
- eval_metrics = []
730
- eval_preds = []
731
- eval_labels = []
732
- eval_sources = []
733
-
734
- eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
735
- eval_steps = len(eval_dataset) // eval_batch_size
736
- for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
737
- # Model forward
738
- batch = next(eval_loader)
739
- labels = batch["labels"]
740
-
741
- metrics = p_eval_step(state.params, batch)
742
- eval_metrics.append(metrics)
743
-
744
- # generation
745
- if data_args.predict_with_generate:
746
- generated_ids = p_generate_step(state.params, batch)
747
- eval_sources.extend(jax.device_get(batch['input_ids'].reshape(-1, batch['input_ids'].shape[-1])))
748
- eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
749
- eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
750
-
751
- # normalize eval metrics
752
- eval_metrics = get_metrics(eval_metrics)
753
- eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
754
-
755
- # compute BLEU, SARI and exact-match metrics
756
- rouge_desc = ""
757
- if data_args.predict_with_generate:
758
- rouge_metrics = compute_metrics(eval_sources, eval_preds, eval_labels)
759
- eval_metrics.update(rouge_metrics)
760
- rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
761
-
762
- # Print metrics and update progress bar
763
- desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
764
- epochs.write(desc)
765
- epochs.desc = desc
766
-
767
- # Save metrics
768
- if has_tensorboard and jax.process_index() == 0:
769
- cur_step = epoch * (len(train_dataset) // train_batch_size)
770
- write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
771
-
772
- # ======================== Prediction loop ==============================
773
- if training_args.do_predict:
774
- logger.info("*** Predict ***")
775
-
776
- pred_metrics = []
777
- pred_generations = []
778
- pred_labels = []
- pred_sources = []
779
-
780
- pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
781
- pred_steps = len(predict_dataset) // eval_batch_size
782
- for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
783
- # Model forward
784
- batch = next(pred_loader)
785
- labels = batch["labels"]
786
-
787
- metrics = p_eval_step(state.params, batch)
788
- pred_metrics.append(metrics)
789
-
790
- # generation
791
- if data_args.predict_with_generate:
792
- generated_ids = p_generate_step(state.params, batch)
793
- pred_sources.extend(jax.device_get(batch['input_ids'].reshape(-1, batch['input_ids'].shape[-1])))
- pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
794
- pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
795
-
796
- # normalize prediction metrics
797
- pred_metrics = get_metrics(pred_metrics)
798
- pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
799
-
800
- # compute BLEU, SARI and exact-match metrics
801
- rouge_desc = ""
802
- if data_args.predict_with_generate:
803
- rouge_metrics = compute_metrics(pred_sources, pred_generations, pred_labels)
804
- pred_metrics.update(rouge_metrics)
805
- rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
806
-
807
- # Print metrics
808
- desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
809
- logger.info(desc)
810
-
811
- # save checkpoint after each epoch and push checkpoint to the hub
812
- if jax.process_index() == 0:
813
- params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
814
- model.save_pretrained(
815
- training_args.output_dir,
816
- params=params,
817
- push_to_hub=training_args.push_to_hub,
818
- commit_message=f"Saving weights and logs of epoch {epoch+1}",
819
- )
820
-
821
-
822
- if __name__ == "__main__":
823
- main()
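Since the script parses its options with `HfArgumentParser`, every dataclass field above becomes a `--flag` on the command line. A hypothetical launch, wrapped in Python for illustration; the model name, column names, and hyperparameters are placeholders, not the configuration actually used to train the checkpoints in this repo:

```python
import subprocess

# Placeholder invocation of the training script above; adjust the column names
# and hyperparameters to match the actual preprocessed wiki_split data.
subprocess.run(
    [
        "python", "run_summarization_flax.py",
        "--output_dir", "./t5-base-wikisplit",
        "--model_name_or_path", "t5-base",
        "--dataset_name", "wiki_split",
        "--text_column", "complex_sentence",    # hypothetical column name
        "--summary_column", "simple_sentence",  # hypothetical column name
        "--max_source_length", "256",
        "--max_target_length", "256",
        "--num_beams", "5",
        "--do_train", "--do_eval", "--do_predict",
        "--predict_with_generate",
        "--per_device_train_batch_size", "8",
        "--per_device_eval_batch_size", "8",
        "--learning_rate", "5e-5",
        "--num_train_epochs", "3",
        "--overwrite_output_dir",
    ],
    check=True,
)
```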