botcon committed on
Commit
e556f98
1 Parent(s): 4a8a37c

Upload QuestionAnswering.py

Files changed (1)
  1. QuestionAnswering.py +446 -0
QuestionAnswering.py ADDED
@@ -0,0 +1,446 @@
+ from transformers import LukePreTrainedModel, LukeModel, AutoTokenizer, TrainingArguments, default_data_collator, Trainer, AutoModelForQuestionAnswering
+ from transformers.modeling_outputs import ModelOutput
+ from typing import Optional, Tuple, Union
+
+ import numpy as np
+ from tqdm import tqdm
+ import evaluate
+ import torch
+ from dataclasses import dataclass
+ from datasets import load_dataset
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ import collections
+
+ PEFT = False
+ tf32 = True
+ fp16 = True
+ train = False
+ test = True
+ trained_model = "SpanBERT_squad_finetuned_qa"
+ train_checkpoint = None
+
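+ # Flags above: `train` / `test` pick which phase runs below, `PEFT` switches on LoRA fine-tuning,
+ # `tf32` / `fp16` control the precision settings, and `trained_model` is the output / Hub directory name.
+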
+ # base_tokenizer = "roberta-base"
+ # base_model = "studio-ousia/luke-base"
+
+ # base_tokenizer = "xlnet-base-cased"
+ # base_model = "xlnet-base-cased"
+
+ base_tokenizer = "bert-base-cased"
+ base_model = "SpanBERT/spanbert-base-cased"
+
+ torch.backends.cuda.matmul.allow_tf32 = tf32
+ torch.backends.cudnn.allow_tf32 = tf32
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+ if tf32:
+     trained_model += "_tf32"
+
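+ # Note: the TF32 switches above only change matmul/cuDNN precision on Ampere-or-newer NVIDIA GPUs;
+ # on other hardware they are no-ops.
+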
+ # https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/luke/modeling_luke.py#L319-L353
+ # Taken from the HF repository, easier to include additional features -- currently identical to HF's LukeForQuestionAnswering
+
+ @dataclass
+ class LukeQuestionAnsweringModelOutput(ModelOutput):
+     """
+     Outputs of question answering models.
+
+     Args:
+         loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+             Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+         start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+             Span-start scores (before SoftMax).
+         end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+             Span-end scores (before SoftMax).
+         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+             Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+         entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+             shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
+             layer plus the initial entity embedding outputs.
+         attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+             sequence_length)`.
+
+             Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+             heads.
+     """
+
+     loss: Optional[torch.FloatTensor] = None
+     start_logits: torch.FloatTensor = None
+     end_logits: torch.FloatTensor = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+ class AugmentedLukeForQuestionAnswering(LukePreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         # This is 2 (one label each for the start and end positions).
+         self.num_labels = config.num_labels
+
+         self.luke = LukeModel(config, add_pooling_layer=False)
+
+         '''
+         Any improvements to the model are expected here. Additional features, anything...
+         '''
+         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
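+     # The QA head above produces two scores per token; `forward` below splits them into start and
+     # end logits over the sequence and, when labels are given, averages the two cross-entropy losses.
+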
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.FloatTensor] = None,
+         entity_ids: Optional[torch.LongTensor] = None,
+         entity_attention_mask: Optional[torch.FloatTensor] = None,
+         entity_token_type_ids: Optional[torch.LongTensor] = None,
+         entity_position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         start_positions: Optional[torch.LongTensor] = None,
+         end_positions: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, LukeQuestionAnsweringModelOutput]:
+         r"""
+         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for position (index) of the start of the labelled span for computing the token classification loss.
+             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+             are not taken into account for computing the loss.
+         end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for position (index) of the end of the labelled span for computing the token classification loss.
+             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+             are not taken into account for computing the loss.
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.luke(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             entity_ids=entity_ids,
+             entity_attention_mask=entity_attention_mask,
+             entity_token_type_ids=entity_token_type_ids,
+             entity_position_ids=entity_position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=True,
+         )
+
+         sequence_output = outputs.last_hidden_state
+
+         logits = self.qa_outputs(sequence_output)
+         start_logits, end_logits = logits.split(1, dim=-1)
+         start_logits = start_logits.squeeze(-1)
+         end_logits = end_logits.squeeze(-1)
+
+         total_loss = None
+         if start_positions is not None and end_positions is not None:
+             # If we are on multi-GPU, the positions may carry an extra dimension; squeeze it away
+             if len(start_positions.size()) > 1:
+                 start_positions = start_positions.squeeze(-1)
+             if len(end_positions.size()) > 1:
+                 end_positions = end_positions.squeeze(-1)
+             # Sometimes the start/end positions are outside our model inputs; we ignore these terms
+             ignored_index = start_logits.size(1)
+             start_positions.clamp_(0, ignored_index)
+             end_positions.clamp_(0, ignored_index)
+
+             loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+             start_loss = loss_fct(start_logits, start_positions)
+             end_loss = loss_fct(end_logits, end_positions)
+             total_loss = (start_loss + end_loss) / 2
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [
+                     total_loss,
+                     start_logits,
+                     end_logits,
+                     outputs.hidden_states,
+                     outputs.entity_hidden_states,
+                     outputs.attentions,
+                 ]
+                 if v is not None
+             )
+
+         return LukeQuestionAnsweringModelOutput(
+             loss=total_loss,
+             start_logits=start_logits,
+             end_logits=end_logits,
+             hidden_states=outputs.hidden_states,
+             entity_hidden_states=outputs.entity_hidden_states,
+             attentions=outputs.attentions,
+         )
+
+ if __name__ == "__main__":
+     # Setting up the tokenizer and helper functions
+     # Work-around for a fast tokenizer - RoBERTa and LUKE share the same subword vocab, and we are not using the entity functions of the LUKE tokenizer anyway
+     tokenizer = AutoTokenizer.from_pretrained(base_tokenizer)
+
+     # Necessary initialization
+     max_length = 384
+     stride = 128
+     batch_size = 8
+     n_best = 20
+     max_answer_length = 30
+     metric = evaluate.load("squad")
+     raw_datasets = load_dataset("squad")
+
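+     # Post-processing: turn the model's start/end logits back into text answers and score them with
+     # the SQuAD metric (exact match and F1). For each example, every overlapping feature is searched
+     # over its n_best start/end candidates and the highest-scoring valid span is kept.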
+     def compute_metrics(start_logits, end_logits, features, examples):
+         example_to_features = collections.defaultdict(list)
+         for idx, feature in enumerate(features):
+             example_to_features[feature["example_id"]].append(idx)
+
+         predicted_answers = []
+         for example in tqdm(examples):
+             example_id = example["id"]
+             context = example["context"]
+             answers = []
+
+             # Loop through all features associated with that example
+             for feature_index in example_to_features[example_id]:
+                 start_logit = start_logits[feature_index]
+                 end_logit = end_logits[feature_index]
+                 offsets = features[feature_index]["offset_mapping"]
+
+                 start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
+                 end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
+                 for start_index in start_indexes:
+                     for end_index in end_indexes:
+                         # Skip answers that are not fully in the context
+                         if offsets[start_index] is None or offsets[end_index] is None:
+                             continue
+                         # Skip answers with a length that is either < 0 or > max_answer_length
+                         if (
+                             end_index < start_index
+                             or end_index - start_index + 1 > max_answer_length
+                         ):
+                             continue
+
+                         answer = {
+                             "text": context[offsets[start_index][0] : offsets[end_index][1]],
+                             "logit_score": start_logit[start_index] + end_logit[end_index],
+                         }
+                         answers.append(answer)
+
+             # Select the answer with the best score
+             if len(answers) > 0:
+                 best_answer = max(answers, key=lambda x: x["logit_score"])
+                 predicted_answers.append(
+                     {"id": example_id, "prediction_text": best_answer["text"]}
+                 )
+             else:
+                 predicted_answers.append({"id": example_id, "prediction_text": ""})
+
+         theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
+         return metric.compute(predictions=predicted_answers, references=theoretical_answers)
+
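+     # Training preprocessing: tokenize question + context with a sliding window and map the
+     # character-level answer span onto token-level start/end positions; features whose window does
+     # not fully contain the answer get the (0, 0) label.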
+     def preprocess_training_examples(examples):
+         questions = [q.strip() for q in examples["question"]]
+         inputs = tokenizer(
+             questions,
+             examples["context"],
+             max_length=max_length,
+             truncation="only_second",
+             stride=stride,
+             return_overflowing_tokens=True,
+             return_offsets_mapping=True,
+             padding="max_length",
+         )
+
+         offset_mapping = inputs.pop("offset_mapping")
+         sample_map = inputs.pop("overflow_to_sample_mapping")
+         answers = examples["answers"]
+         start_positions = []
+         end_positions = []
+
+         for i, offset in enumerate(offset_mapping):
+             sample_idx = sample_map[i]
+             answer = answers[sample_idx]
+             start_char = answer["answer_start"][0]
+             end_char = answer["answer_start"][0] + len(answer["text"][0])
+             sequence_ids = inputs.sequence_ids(i)
+
+             # Find the start and end of the context
+             idx = 0
+             while sequence_ids[idx] != 1:
+                 idx += 1
+             context_start = idx
+             while sequence_ids[idx] == 1:
+                 idx += 1
+             context_end = idx - 1
+
+             # If the answer is not fully inside the context, label is (0, 0)
+             if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
+                 start_positions.append(0)
+                 end_positions.append(0)
+             else:
+                 # Otherwise it's the start and end token positions
+                 idx = context_start
+                 while idx <= context_end and offset[idx][0] <= start_char:
+                     idx += 1
+                 start_positions.append(idx - 1)
+
+                 idx = context_end
+                 while idx >= context_start and offset[idx][1] >= end_char:
+                     idx -= 1
+                 end_positions.append(idx + 1)
+
+         inputs["start_positions"] = start_positions
+         inputs["end_positions"] = end_positions
+         return inputs
+
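+     # Validation preprocessing: same sliding-window tokenization, but instead of labels we record the
+     # example id for every feature and set the offset mapping of non-context tokens to None so that
+     # compute_metrics can skip them.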
+     def preprocess_validation_examples(examples):
+         questions = [q.strip() for q in examples["question"]]
+         inputs = tokenizer(
+             questions,
+             examples["context"],
+             max_length=max_length,
+             truncation="only_second",
+             stride=stride,
+             return_overflowing_tokens=True,
+             return_offsets_mapping=True,
+             padding="max_length",
+         )
+
+         sample_map = inputs.pop("overflow_to_sample_mapping")
+         example_ids = []
+
+         for i in range(len(inputs["input_ids"])):
+             sample_idx = sample_map[i]
+             example_ids.append(examples["id"][sample_idx])
+
+             sequence_ids = inputs.sequence_ids(i)
+             offset = inputs["offset_mapping"][i]
+             inputs["offset_mapping"][i] = [
+                 o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
+             ]
+
+         inputs["example_id"] = example_ids
+         return inputs
+
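+     # Training branch: fine-tune the chosen base model on SQuAD, optionally wrapping it in LoRA
+     # adapters when PEFT is enabled.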
+     if train:
+         model = AutoModelForQuestionAnswering.from_pretrained(base_model).to(device)
+
+         train_dataset = raw_datasets["train"].map(
+             preprocess_training_examples,
+             batched=True,
+             remove_columns=raw_datasets["train"].column_names,
+         )
+
+         validation_dataset = raw_datasets["validation"].map(
+             preprocess_validation_examples,
+             batched=True,
+             remove_columns=raw_datasets["validation"].column_names,
+         )
+
+         # --------------- PEFT -------------------- # One epoch without PEFT took about 2h on my computer with CUDA - PEFT results were noticeably worse, though
+         if PEFT:
+             from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
+
+             # ---- For all linear layers ----
+             import re
+             pattern = r'\((\w+)\): Linear'
+             linear_layers = re.findall(pattern, str(model.modules))
+             target_modules = list(set(linear_layers))
+
+             # If using PEFT, consider increasing r for better performance
+             peft_config = LoraConfig(
+                 task_type=TaskType.QUESTION_ANS, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1, target_modules=target_modules, bias='all'
+             )
+
+             model = get_peft_model(model, peft_config)
+             model.print_trainable_parameters()
+
+             trained_model += "_PEFT"
+
+         # ------------------------------------------ #
+
+         args = TrainingArguments(
+             trained_model,
+             evaluation_strategy="no",
+             save_strategy="epoch",
+             learning_rate=2e-5,
+             per_device_train_batch_size=batch_size,
+             per_device_eval_batch_size=batch_size,
+             num_train_epochs=3,
+             weight_decay=0.01,
+             push_to_hub=True,
+             fp16=fp16,
+         )
+
+         trainer = Trainer(
+             model,
+             args,
+             train_dataset=train_dataset,
+             eval_dataset=validation_dataset,
+             data_collator=default_data_collator,
+             tokenizer=tokenizer,
+         )
+
+         trainer.train(train_checkpoint)
+
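+     # Evaluation branch: load the fine-tuned checkpoint and score the SQuAD validation set in 100
+     # equally sized slices, averaging the per-slice metrics (any remainder examples beyond
+     # interval * 100 are skipped).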
+     elif test:
+         model = AutoModelForQuestionAnswering.from_pretrained(trained_model).to(device)
+
+         interval = len(raw_datasets["validation"]) // 100
+         exact_match = 0
+         f1 = 0
+
+         with torch.no_grad():
+             for i in range(1, 101):
+                 start = interval * (i - 1)
+                 end = interval * i
+                 small_eval_set = raw_datasets["validation"].select(range(start, end))
+                 eval_set = small_eval_set.map(
+                     preprocess_validation_examples,
+                     batched=True,
+                     remove_columns=raw_datasets["validation"].column_names
+                 )
+                 eval_set_for_model = eval_set.remove_columns(["example_id", "offset_mapping"])
+                 eval_set_for_model.set_format("torch")
+                 batch = {k: eval_set_for_model[k].to(device) for k in eval_set_for_model.column_names}
+                 outputs = model(**batch)
+                 start_logits = outputs.start_logits.cpu().numpy()
+                 end_logits = outputs.end_logits.cpu().numpy()
+                 res = compute_metrics(start_logits, end_logits, eval_set, small_eval_set)
+                 exact_match += res['exact_match']
+                 f1 += res["f1"]
+         print("F1 score: {}".format(f1 / 100))
+         print("Exact match: {}".format(exact_match / 100))
+
+     # XLNet
+     # F1 score: 91.54154256653278
+     # Exact match: 84.86666666666666
+
+     # SpanBERT
+     # F1 score: 92.160285362531
+     # Exact match: 85.73333333333333