patrickvonplaten committed
Commit 51b5d5c · Parent(s): c83632a

Update README.md

README.md CHANGED

@@ -42,174 +42,8 @@

## Training script:

Removed:

Training makes use of the `Trainer` for `EncoderDecoderModels` according to this PR: https://github.com/huggingface/transformers/pull/5840.

The following code shows the complete training script that was used to fine-tune `bert2bert-cnn_dailymail-fp16` for reproducibility. The training lasted ~9h on a standard GPU.

```python
#!/usr/bin/env python3
import nlp
import logging
from transformers import BertTokenizer, EncoderDecoderModel, Trainer, TrainingArguments

logging.basicConfig(level=logging.INFO)

model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# CLS token will work as BOS token
tokenizer.bos_token = tokenizer.cls_token

# SEP token will work as EOS token
tokenizer.eos_token = tokenizer.sep_token

# load train and validation data
train_dataset = nlp.load_dataset("cnn_dailymail", "3.0.0", split="train")
val_dataset = nlp.load_dataset("cnn_dailymail", "3.0.0", split="validation[:10%]")

# load rouge for validation
rouge = nlp.load_metric("rouge")


# set decoding params
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.max_length = 142
model.config.min_length = 56
model.config.no_repeat_ngram_size = 3
model.early_stopping = True
model.length_penalty = 2.0
model.num_beams = 4


# map data correctly
def map_to_encoder_decoder_inputs(batch):
    # Tokenizer will automatically set [BOS] <text> [EOS]
    # cut off at BERT max length 512
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
    # force summarization <= 128
    outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask

    batch["decoder_input_ids"] = outputs.input_ids
    batch["labels"] = outputs.input_ids.copy()
    # mask loss for padding
    batch["labels"] = [
        [-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
    ]
    batch["decoder_attention_mask"] = outputs.attention_mask

    assert all([len(x) == 512 for x in inputs.input_ids])
    assert all([len(x) == 128 for x in outputs.input_ids])

    return batch


def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # all unnecessary tokens are removed
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }


# set batch size here
batch_size = 16

# make train dataset ready
train_dataset = train_dataset.map(
    map_to_encoder_decoder_inputs, batched=True, batch_size=batch_size, remove_columns=["article", "highlights"],
)
train_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

# same for validation dataset
val_dataset = val_dataset.map(
    map_to_encoder_decoder_inputs, batched=True, batch_size=batch_size, remove_columns=["article", "highlights"],
)
val_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

# set training arguments - these params are not really tuned, feel free to change
training_args = TrainingArguments(
    output_dir="./",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_from_generate=True,
    evaluate_during_training=True,
    do_train=True,
    do_eval=True,
    logging_steps=1000,
    save_steps=1000,
    eval_steps=1000,
    overwrite_output_dir=True,
    warmup_steps=2000,
    save_total_limit=10,
)

# instantiate trainer
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# start training
trainer.train()
```
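
Once training has finished, the checkpoint can be sanity-checked on a single article before running the full test-set evaluation below. The snippet is a minimal sketch, not part of the original card: it assumes the fine-tuned weights are loaded from the published `patrickvonplaten/bert2bert-cnn_dailymail-fp16` checkpoint, reuses the same tokenization and `generate` calls as the evaluation script, and the article string is only a placeholder.

```python
from transformers import BertTokenizer, EncoderDecoderModel

tokenizer = BertTokenizer.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")

# placeholder text - replace with a real CNN/Daily Mail article
article = "..."

# tokenize as in the evaluation script: truncate to BERT's 512-token limit
inputs = tokenizer(article, padding="max_length", truncation=True, max_length=512, return_tensors="pt")

# decoding parameters (max/min length, no_repeat_ngram_size, ...) are read from the model config
output_ids = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask)

# drop all special tokens from the decoded summary
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
```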

## Evaluation

The following script evaluates the model on the test set of CNN/Daily Mail.

```python
#!/usr/bin/env python3
import nlp
from transformers import BertTokenizer, EncoderDecoderModel

tokenizer = BertTokenizer.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
model.to("cuda")

test_dataset = nlp.load_dataset("cnn_dailymail", "3.0.0", split="test")

batch_size = 128


# map data correctly
def generate_summary(batch):
    # Tokenizer will automatically set [BOS] <text> [EOS]
    # cut off at BERT max length 512
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    input_ids = inputs.input_ids.to("cuda")
    attention_mask = inputs.attention_mask.to("cuda")

    outputs = model.generate(input_ids, attention_mask=attention_mask)

    # all special tokens will be removed
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    batch["pred"] = output_str

    return batch


results = test_dataset.map(generate_summary, batched=True, batch_size=batch_size, remove_columns=["article"])

# load rouge for validation
rouge = nlp.load_metric("rouge")

pred_str = results["pred"]
label_str = results["highlights"]

rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

print(rouge_output)
```

Added:

Please follow this tutorial to see how to warm-start a BERT2BERT model: https://colab.research.google.com/drive/1WIk2bxglElfZewOHboPFNj8H44_VAyKE?usp=sharing
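
For quick reference while following the tutorial: warm-starting boils down to loading two pretrained BERT checkpoints into a single `EncoderDecoderModel` and reusing BERT's CLS/SEP tokens as BOS/EOS, exactly as in the removed training script above. A minimal sketch:

```python
from transformers import BertTokenizer, EncoderDecoderModel

# warm-start: encoder and decoder weights are both initialized from bert-base-uncased
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# BERT has no dedicated BOS/EOS tokens, so CLS serves as BOS and SEP as EOS
tokenizer.bos_token = tokenizer.cls_token
tokenizer.eos_token = tokenizer.sep_token

# tell the decoder where generation starts and stops
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
```

From there, fine-tuning on CNN/Daily Mail proceeds as in the tutorial (or the removed script above).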

The obtained results should be: