stefanbschneider committed · Commit 6f35c2d · verified · 1 Parent(s): 977471c

add train script

Files changed (1)
  1. led-finetune-lfqa-train.py +181 -0
led-finetune-lfqa-train.py ADDED
@@ -0,0 +1,181 @@
from datetime import datetime
from typing import Optional
from datasets import load_dataset
import evaluate
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    GenerationConfig,
)
import wandb


# See https://huggingface.co/docs/transformers/en/perf_train_gpu_one
BATCH_SIZE: int = 2
# Max allowed answer length in tokens --> select corresponding processed dataset and set allowed decoder len
MAX_ANSWER_LENGTH: int = 512

# initialize wandb for monitoring
run_name: str = f"vast-gpu_batch-size-{BATCH_SIZE}_ans-len-{MAX_ANSWER_LENGTH}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
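# e.g., "vast-gpu_batch-size-2_ans-len-512-2025-01-31_12-00-00" (the date portion is illustrative)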
wandb.init(project="led-finetune-lfqa", name=run_name)  # type: ignore

# load rouge for evaluation
rouge = evaluate.load("rouge")

# load model and tokenizer
# larger model for better performance, but slower to train: allenai/led-large-16384
pretrained_model_name = "allenai/led-base-16384"
my_model_name = f"stefanbschneider/led-base-16384-lfqa-ans-len-{MAX_ANSWER_LENGTH}"
model_name = my_model_name
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Enable gradient checkpointing to reduce memory during training (at the cost of speed)
model.gradient_checkpointing_enable()
tokenizer = AutoTokenizer.from_pretrained(model_name)


def process_data_to_model_inputs(batch):
    # combine context strings and questions to one input
    input = [
        f"question: {question}, context: {' '.join(context)}"
        for question, context in zip(batch["question"], batch["context"])
    ]

    # tokenize the inputs and labels
    inputs = tokenizer(
        input,
        padding="max_length",
        truncation=True,
        # Max supported article/context length + question.
        max_length=8192,
    )
    outputs = tokenizer(
        batch["answer"],
        padding="max_length",
        truncation=True,
        # Answers in the dataset should be limited to MAX_ANSWER_LENGTH already (see my dataset description)
        max_length=MAX_ANSWER_LENGTH,
    )

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask

    # create 0 global_attention_mask lists
    batch["global_attention_mask"] = len(batch["input_ids"]) * [
        [0 for _ in range(len(batch["input_ids"][0]))]
    ]

    # since above lists are references, the following line changes the 0 index for all samples
    batch["global_attention_mask"][0][0] = 1
    batch["labels"] = outputs.input_ids

    # We have to make sure that the PAD token is ignored
    batch["labels"] = [
        [-100 if token == tokenizer.pad_token_id else token for token in labels]
        for labels in batch["labels"]
    ]

    return batch


def load_and_process_dataset(split: str, dataset_limit: Optional[int] = None):
    """Load and process the dataset for training or validation. Optionally limit the number of samples."""
    dataset = load_dataset(f"stefanbschneider/lfqa-max-answer-length-{MAX_ANSWER_LENGTH}", split=split)

    # optionally reduce the data sets to a small fraction
    if dataset_limit is not None:
        dataset = dataset.select(range(dataset_limit))

    dataset = dataset.map(
        process_data_to_model_inputs,
        batched=True,
        batch_size=BATCH_SIZE,
        remove_columns=["context", "question", "answer"],
    )

    dataset.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "global_attention_mask", "labels"],
    )

    return dataset


def compute_metrics(pred) -> dict[str, float]:
    """Compute rouge score during validation/evaluation"""
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(
        predictions=pred_str, references=label_str, rouge_types=["rouge2"]
    )["rouge2"]

    # Return rouge2 F1 score
    # There are no longer separate precision and recall values in rouge2
    return {"rouge2": round(rouge_output, 4)}


# Load and process datasets
train_data = load_and_process_dataset("train", dataset_limit=None)
val_data = load_and_process_dataset("validation", dataset_limit=64)


# Create and set generation config
generation_config = GenerationConfig(
    # The generated answer/summary should be 100-MAX_ANSWER_LENGTH tokens long
    max_length=MAX_ANSWER_LENGTH,
    min_length=100,
    early_stopping=True,
    num_beams=4,
    length_penalty=2.0,
    # Don't repeat n=3-grams (same words in same order) in the generated text --> more natural
    no_repeat_ngram_size=3,
    decoder_start_token_id=tokenizer.cls_token_id,
    bos_token_id=tokenizer.bos_token_id,
)
model.generation_config = generation_config

# Set training arguments
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    eval_strategy="steps",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # fp16 only works on GPU, not on M1 mps. mps is used by default if it's available
    fp16=True,
    output_dir=f"models/{my_model_name}",
    logging_steps=50,  # 50,
    eval_steps=500,  # 100,
    save_steps=100,  # 100,
    # warmup_steps=100,
    save_total_limit=1,
    gradient_accumulation_steps=1,
    # num_train_epochs=1,
    # Save to HF hub & log to wandb
    push_to_hub=True,
    hub_model_id=my_model_name,
    log_level="info",
    report_to="wandb",
    run_name=run_name,
)

# start training
# Total steps = (num examples in data / (batch size * gradient accumulation steps)) * num epochs
# The gradient accumulation adds multiple batches together before updating the weights
# https://huggingface.co/docs/transformers/en/perf_train_gpu_one
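# E.g. (hypothetical numbers): 100,000 training samples with batch size 2, gradient accumulation 1,
# and the Trainer default of 3 epochs gives (100,000 / (2 * 1)) * 3 = 150,000 total steps.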
trainer = Seq2SeqTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=val_data,
)
trainer.train()  # resume_from_checkpoint=True
trainer.push_to_hub()
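
# Example inference with the fine-tuned model (a minimal sketch; the question and context strings
# below are placeholders, not from the training data):
#   text = "question: <your question>, context: <your context>"
#   inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=8192)
#   output_ids = model.generate(**inputs)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))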