add train script
led-finetune-lfqa-train.py +181 -0
led-finetune-lfqa-train.py
ADDED
@@ -0,0 +1,181 @@
from datetime import datetime
from typing import Optional
from datasets import load_dataset
import evaluate
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    GenerationConfig,
)
import wandb


# See https://huggingface.co/docs/transformers/en/perf_train_gpu_one
BATCH_SIZE: int = 2
# Max allowed answer length in tokens --> select corresponding processed dataset and set allowed decoder len
MAX_ANSWER_LENGTH: int = 512

# Initialize wandb for monitoring
run_name: str = f"vast-gpu_batch-size-{BATCH_SIZE}_ans-len-{MAX_ANSWER_LENGTH}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
wandb.init(project="led-finetune-lfqa", name=run_name)  # type: ignore

# Load ROUGE for evaluation
rouge = evaluate.load("rouge")

# Load model and tokenizer
# Larger model for better performance, but slower to train: allenai/led-large-16384
pretrained_model_name = "allenai/led-base-16384"
my_model_name = f"stefanbschneider/led-base-16384-lfqa-ans-len-{MAX_ANSWER_LENGTH}"
model_name = my_model_name
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Enable gradient checkpointing to reduce memory during training (at the cost of speed)
model.gradient_checkpointing_enable()
tokenizer = AutoTokenizer.from_pretrained(model_name)


def process_data_to_model_inputs(batch):
    # Combine context strings and questions into one input
    input = [
        f"question: {question}, context: {' '.join(context)}"
        for question, context in zip(batch["question"], batch["context"])
    ]

    # Tokenize the inputs and labels
    inputs = tokenizer(
        input,
        padding="max_length",
        truncation=True,
        # Max supported article/context length + question.
        max_length=8192,
    )
    outputs = tokenizer(
        batch["answer"],
        padding="max_length",
        truncation=True,
        # Answers in the dataset should be limited to MAX_ANSWER_LENGTH already (see my dataset description)
        max_length=MAX_ANSWER_LENGTH,
    )

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask

    # Create all-zero global_attention_mask lists
    batch["global_attention_mask"] = len(batch["input_ids"]) * [
        [0 for _ in range(len(batch["input_ids"][0]))]
    ]

    # Since the above lists are references, the following line changes index 0 for all samples
    batch["global_attention_mask"][0][0] = 1
    batch["labels"] = outputs.input_ids

    # We have to make sure that the PAD token is ignored in the loss
    batch["labels"] = [
        [-100 if token == tokenizer.pad_token_id else token for token in labels]
        for labels in batch["labels"]
    ]

    return batch


def load_and_process_dataset(split: str, dataset_limit: Optional[int] = None):
    """Load and process the dataset for training or validation. Optionally limit the number of samples."""
    dataset = load_dataset(f"stefanbschneider/lfqa-max-answer-length-{MAX_ANSWER_LENGTH}", split=split)

    # Optionally reduce the dataset to a small fraction
    if dataset_limit is not None:
        dataset = dataset.select(range(dataset_limit))

    dataset = dataset.map(
        process_data_to_model_inputs,
        batched=True,
        batch_size=BATCH_SIZE,
        remove_columns=["context", "question", "answer"],
    )

    dataset.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "global_attention_mask", "labels"],
    )

    return dataset


def compute_metrics(pred) -> dict[str, float]:
    """Compute ROUGE score during validation/evaluation."""
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(
        predictions=pred_str, references=label_str, rouge_types=["rouge2"]
    )["rouge2"]

    # Return the rouge2 F1 score
    # There are no longer separate precision and recall values in rouge2
    return {"rouge2": round(rouge_output, 4)}


# Load and process datasets
train_data = load_and_process_dataset("train", dataset_limit=None)
val_data = load_and_process_dataset("validation", dataset_limit=64)


# Create and set generation config
generation_config = GenerationConfig(
    # The generated answer/summary should be 100-MAX_ANSWER_LENGTH tokens long
    max_length=MAX_ANSWER_LENGTH,
    min_length=100,
    early_stopping=True,
    num_beams=4,
    length_penalty=2.0,
    # Don't repeat n=3-grams (same words in same order) in the generated text --> more natural
    no_repeat_ngram_size=3,
    decoder_start_token_id=tokenizer.cls_token_id,
    bos_token_id=tokenizer.bos_token_id,
)
model.generation_config = generation_config

# Set training arguments
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    eval_strategy="steps",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # fp16 only works on GPU, not on M1 mps. mps is used by default if it's available
    fp16=True,
    output_dir=f"models/{my_model_name}",
    logging_steps=50,
    eval_steps=500,
    save_steps=100,
    # warmup_steps=100,
    save_total_limit=1,
    gradient_accumulation_steps=1,
    # num_train_epochs=1,
    # Save to HF hub & log to wandb
    push_to_hub=True,
    hub_model_id=my_model_name,
    log_level="info",
    report_to="wandb",
    run_name=run_name,
)

# Start training
# Total steps = (num examples in data / (batch size * gradient accumulation steps)) * num epochs
# Gradient accumulation adds up multiple batches before updating the weights
# https://huggingface.co/docs/transformers/en/perf_train_gpu_one
trainer = Seq2SeqTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=val_data,
)
trainer.train()  # resume_from_checkpoint=True
trainer.push_to_hub()
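For reference, the script pushes the fine-tuned checkpoint to the Hub as stefanbschneider/led-base-16384-lfqa-ans-len-512 (with the default MAX_ANSWER_LENGTH of 512). A minimal inference sketch for that checkpoint could look like the following; the question and context strings are placeholders, and the input formatting plus the global attention on the first token mirror process_data_to_model_inputs above.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Checkpoint name taken from the training script above (MAX_ANSWER_LENGTH = 512)
model_id = "stefanbschneider/led-base-16384-lfqa-ans-len-512"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

# Placeholder question/context, only to illustrate the expected input format
question = "How do solar panels work?"
context = "Solar panels convert sunlight into electricity using photovoltaic cells."

# Same "question: ..., context: ..." format as used during training
inputs = tokenizer(
    f"question: {question}, context: {context}",
    return_tensors="pt",
    truncation=True,
    max_length=8192,
)

# LED uses a global attention mask; mark the first token as global, as in training
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1

# generate() picks up the generation_config saved with the model (beam search, length limits)
output_ids = model.generate(**inputs, global_attention_mask=global_attention_mask)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))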