pragnakalp committed on
Commit
5c4ebbb
1 Parent(s): d6de82a

Upload questiongenerator.py

Browse files
Files changed (1) hide show
  1. questiongenerator.py +345 -0
questiongenerator.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+ import spacy
7
+ import re
8
+ import random
9
+ import json
10
+ import en_core_web_sm
11
+ from string import punctuation
12
+
13
+ #from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
14
+ #from transformers import BertTokenizer, BertForSequenceClassification
15
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
16
class QuestionGenerator():
    """Generate questions from free text with a pretrained seq2seq model.

    The 'iarfmoose/t5-base-question-generator' checkpoint is loaded through
    ``AutoTokenizer`` / ``AutoModelForSeq2SeqLM``. Each model input has the
    form ``"<answer> ANSWER <context> CONTEXT"``.
    """

    def __init__(self, model_dir=None):
        """Load the tokenizer, the question-generation model and a QAEvaluator.

        Args:
            model_dir: Forwarded unchanged to ``QAEvaluator``; not otherwise
                used by this class.
        """
        QG_PRETRAINED = 'iarfmoose/t5-base-question-generator'
        self.ANSWER_TOKEN = '<answer>'
        self.CONTEXT_TOKEN = '<context>'
        self.SEQ_LENGTH = 512

        # Inference is pinned to the CPU. To use a GPU when one is available:
        # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = torch.device('cpu')

        self.qg_tokenizer = AutoTokenizer.from_pretrained(QG_PRETRAINED)
        self.qg_model = AutoModelForSeq2SeqLM.from_pretrained(QG_PRETRAINED)
        self.qg_model.to(self.device)

        self.qa_evaluator = QAEvaluator(model_dir)

    def generate(self, article, use_evaluator=True, num_questions=None, answer_style='all'):
        """Generate questions for ``article``.

        Args:
            article: Source text.
            use_evaluator: Accepted for interface compatibility. It currently
                has no effect: this method returns the generated questions
                directly and never runs the QA evaluator.
            num_questions: Maximum number of unique questions to generate
                (anything ``int()`` accepts). ``None`` means no limit.
            answer_style: One of 'all', 'sentences' or 'multiple_choice'.

        Returns:
            A list of question strings.

        Raises:
            ValueError: If ``answer_style`` is not a supported style.
        """
        print("Generating questions...\n")

        qg_inputs, qg_answers = self.generate_qg_inputs(article, answer_style)
        print("qg_inputs, qg_answers=>",qg_inputs, qg_answers)
        generated_questions = self.generate_questions_from_inputs(qg_inputs,num_questions)
        print("generated_questions(generate)=>",generated_questions)
        return generated_questions

    def generate_qg_inputs(self, text, answer_style):
        """Build model inputs and their answers for the requested style.

        Args:
            text: Source text.
            answer_style: 'sentences' uses each sentence of each segment as
                the answer, 'multiple_choice' uses named entities, and 'all'
                does both.

        Returns:
            A tuple ``(inputs, answers)`` of equally long lists.

        Raises:
            ValueError: If ``answer_style`` is not a supported style.
        """
        VALID_ANSWER_STYLES = ['all', 'sentences', 'multiple_choice']

        if answer_style not in VALID_ANSWER_STYLES:
            raise ValueError(
                "Invalid answer style {}. Please choose from {}".format(
                    answer_style,
                    VALID_ANSWER_STYLES
                )
            )

        inputs = []
        answers = []

        if answer_style in ('sentences', 'all'):
            for segment in self._split_into_segments(text):
                sentences = self._split_text(segment)
                prepped_inputs, prepped_answers = self._prepare_qg_inputs(sentences, segment)
                inputs.extend(prepped_inputs)
                answers.extend(prepped_answers)

        if answer_style in ('multiple_choice', 'all'):
            sentences = self._split_text(text)
            prepped_inputs, prepped_answers = self._prepare_qg_inputs_MC(sentences)
            inputs.extend(prepped_inputs)
            answers.extend(prepped_answers)

        return inputs, answers

    def generate_questions_from_inputs(self, qg_inputs, num_questions=None):
        """Generate up to ``num_questions`` unique questions.

        Each generated question is stripped of surrounding whitespace and
        punctuation and terminated with exactly one '?'. Duplicates are
        dropped and do not count towards the limit.

        Args:
            qg_inputs: Model input strings, processed in order.
            num_questions: Limit (anything ``int()`` accepts), or ``None``
                for no limit.

        Returns:
            A list of unique question strings in generation order.
        """
        # None used to crash in int(); treat it as "no limit".
        limit = None if num_questions is None else int(num_questions)
        print("num que => ", num_questions)

        generated_questions = []
        for qg_input in qg_inputs:
            if limit is not None and len(generated_questions) >= limit:
                break
            question = self._generate_question(qg_input)
            # Normalise the ending: strip whitespace/punctuation, add one '?'.
            question = question.strip().strip(punctuation) + "?"
            if question not in generated_questions:
                generated_questions.append(question)
                print("question ===> ",question)
        return generated_questions

    def _split_text(self, text):
        """Split ``text`` into candidate answer sentences.

        Sentences end at '.', '!' or '?'. Sentences longer than
        ``MAX_SENTENCE_LEN`` characters are additionally split at
        ',', ';', ':' and ')'; resulting fragments of more than five
        space-separated tokens are kept as extra candidates.

        Returns:
            Unique, space-stripped strings in first-seen order.
        """
        MAX_SENTENCE_LEN = 128

        sentences = re.findall(r'.*?[.!?]', text)

        cut_sentences = []
        for sentence in sentences:
            if len(sentence) > MAX_SENTENCE_LEN:
                cut_sentences.extend(re.split('[,;:)]', sentence))
        # Drop short fragments. The filter must run over the fragments; it
        # used to run over `sentences`, which discarded every fragment.
        cut_sentences = [s for s in cut_sentences if len(s.split(" ")) > 5]

        candidates = (s.strip(" ") for s in sentences + cut_sentences)
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        return list(dict.fromkeys(s for s in candidates if s))

    def _split_into_segments(self, text):
        """Group the paragraphs of ``text`` into segments of bounded size.

        Non-empty lines are tokenized and concatenated until a segment
        holds at least ``MAX_TOKENS`` token ids, so a segment may exceed
        the threshold by part of its last paragraph.

        Returns:
            The decoded text of every segment.
        """
        MAX_TOKENS = 490

        paragraphs = text.split('\n')
        tokenized_paragraphs = [self.qg_tokenizer(p)['input_ids'] for p in paragraphs if len(p) > 0]

        segments = []
        segment = []
        for paragraph in tokenized_paragraphs:
            segment.extend(paragraph)
            if len(segment) >= MAX_TOKENS:
                segments.append(segment)
                segment = []
        if segment:
            segments.append(segment)

        # Skip special tokens added while encoding (e.g. end-of-sequence
        # markers) so they do not leak into the decoded context text.
        return [self.qg_tokenizer.decode(s, skip_special_tokens=True) for s in segments]

    def _prepare_qg_inputs(self, sentences, text):
        """Pair every sentence (the answer) with ``text`` (the context).

        Returns:
            A tuple ``(inputs, answers)``; ``answers`` is a copy of
            ``sentences`` as a list.
        """
        inputs = [
            '{} {} {} {}'.format(self.ANSWER_TOKEN, sentence, self.CONTEXT_TOKEN, text)
            for sentence in sentences
        ]
        return inputs, list(sentences)

    def _prepare_qg_inputs_MC(self, sentences):
        """Build multiple-choice inputs from the named entities in ``sentences``.

        Every entity found by spaCy becomes the answer of one input whose
        context is the sentence containing it.

        Returns:
            A tuple ``(inputs, answers)``; each answer is the list of
            choices produced by ``_get_MC_answers``.
        """
        # Loading the spaCy pipeline is expensive; do it once per instance.
        spacy_nlp = getattr(self, '_spacy_nlp', None)
        if spacy_nlp is None:
            spacy_nlp = en_core_web_sm.load()
            self._spacy_nlp = spacy_nlp

        docs = list(spacy_nlp.pipe(sentences, disable=['parser']))
        inputs_from_text = []
        answers_from_text = []

        for sentence, doc in zip(sentences, docs):
            for entity in doc.ents:
                qg_input = '{} {} {} {}'.format(
                    self.ANSWER_TOKEN,
                    entity,
                    self.CONTEXT_TOKEN,
                    sentence
                )
                inputs_from_text.append(qg_input)
                answers_from_text.append(self._get_MC_answers(entity, docs))

        return inputs_from_text, answers_from_text

    def _get_MC_answers(self, correct_answer, docs):
        """Build a shuffled list of up to four multiple-choice answers.

        Distractors are drawn from the other entities in ``docs``,
        preferring entities with the same label as the correct answer and
        filling up with randomly chosen others.

        Args:
            correct_answer: Object with ``text`` and ``label_`` attributes.
            docs: Iterable of objects whose ``ents`` yield such entities.

        Returns:
            A list of ``{'answer': str, 'correct': bool}`` dicts containing
            exactly one correct entry.
        """
        # Unique entity texts in first-seen order, mapped to their label.
        label_by_text = {}
        for doc in docs:
            for entity in doc.ents:
                label_by_text.setdefault(entity.text, entity.label_)

        correct_text = correct_answer.text
        correct_label = correct_answer.label_
        # A distractor must never repeat the text of the correct answer.
        distractors = [(t, l) for t, l in label_by_text.items() if t != correct_text]
        num_choices = min(3, len(distractors))

        same_label = [t for t, l in distractors if l == correct_label]
        if len(same_label) < num_choices:
            others = [t for t, l in distractors if l != correct_label]
            # Sample from lists: random.sample rejects sets on Python 3.11+.
            choices = same_label + random.sample(others, num_choices - len(same_label))
        else:
            choices = random.sample(same_label, num_choices)

        final_choices = [{'answer': correct_text, 'correct': True}]
        final_choices.extend({'answer': text, 'correct': False} for text in choices)
        random.shuffle(final_choices)
        return final_choices

    def _generate_question(self, qg_input):
        """Run the model on one input string and return the decoded question."""
        self.qg_model.eval()
        encoded_input = self._encode_qg_input(qg_input)
        with torch.no_grad():
            output = self.qg_model.generate(input_ids=encoded_input['input_ids'])
        # Skip special tokens so padding/end markers do not end up in the text.
        return self.qg_tokenizer.decode(output[0], skip_special_tokens=True)

    def _encode_qg_input(self, qg_input):
        """Tokenize ``qg_input`` padded/truncated to ``SEQ_LENGTH`` on ``self.device``."""
        return self.qg_tokenizer(
            qg_input,
            padding='max_length',  # replaces the deprecated pad_to_max_length=True
            max_length=self.SEQ_LENGTH,
            truncation=True,
            return_tensors="pt"
        ).to(self.device)

    def _get_ranked_qa_pairs(self, generated_questions, qg_answers, scores, num_questions=10):
        """Return the ``num_questions`` best question/answer pairs.

        Args:
            generated_questions: Question strings.
            qg_answers: Answers aligned with ``generated_questions``.
            scores: Indices into both lists, best first.
            num_questions: Number of pairs wanted; capped at ``len(scores)``.

        Returns:
            A list of ``{'question': ..., 'answer': ...}`` dicts.
        """
        if num_questions > len(scores):
            num_questions = len(scores)
            print("\nWas only able to generate {} questions. For more questions, please input a longer text.".format(num_questions))

        return [
            self._make_dict(
                generated_questions[index].split('?')[0] + '?',
                qg_answers[index])
            for index in scores[:max(num_questions, 0)]
        ]

    def _get_all_qa_pairs(self, generated_questions, qg_answers):
        """Return every question/answer pair in generation order.

        Each question is cut after its first '?'.
        """
        return [
            self._make_dict(question.split('?')[0] + '?', qg_answers[i])
            for i, question in enumerate(generated_questions)
        ]

    def _make_dict(self, question, answer):
        """Bundle a question and its answer into one dict."""
        return {'question': question, 'answer': answer}
261
+
262
+
263
class QAEvaluator():
    """Score question/answer pairs with a pretrained sequence classifier.

    The 'iarfmoose/bert-base-cased-qa-evaluator' checkpoint is loaded through
    ``AutoTokenizer`` / ``AutoModelForSequenceClassification``.
    """

    def __init__(self, model_dir=None):
        """Load the evaluator tokenizer and model.

        Args:
            model_dir: Accepted for interface compatibility; not used here.
        """
        QAE_PRETRAINED = 'iarfmoose/bert-base-cased-qa-evaluator'
        self.SEQ_LENGTH = 512

        # Inference is pinned to the CPU. To use a GPU when one is available:
        # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = torch.device('cpu')

        self.qae_tokenizer = AutoTokenizer.from_pretrained(QAE_PRETRAINED)
        self.qae_model = AutoModelForSequenceClassification.from_pretrained(QAE_PRETRAINED)
        self.qae_model.to(self.device)

    def encode_qa_pairs(self, questions, answers):
        """Encode every question with the answer at the same position.

        Returns:
            A list of tokenizer encodings moved to ``self.device``.
        """
        return [
            self._encode_qa(question, answers[i]).to(self.device)
            for i, question in enumerate(questions)
        ]

    def get_scores(self, encoded_qa_pairs):
        """Rank encoded pairs by evaluator score.

        Returns:
            The indices of ``encoded_qa_pairs`` ordered from highest to
            lowest score; equal scores keep their input order.
        """
        self.qae_model.eval()
        with torch.no_grad():
            # Plain floats make the sort independent of tensor comparison.
            scores = [float(self._evaluate_qa(pair)) for pair in encoded_qa_pairs]

        return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

    def _encode_qa(self, question, answer):
        """Tokenize a question together with its correct answer.

        Args:
            question: Question text.
            answer: Either the answer itself, or a list of
                ``{'answer': ..., 'correct': ...}`` choices from which the
                entry flagged correct is used.

        Raises:
            ValueError: If ``answer`` is a list without a correct choice.
        """
        if isinstance(answer, list):
            flagged = [a['answer'] for a in answer if a['correct']]
            if not flagged:
                raise ValueError("No choice is marked as correct for question: {}".format(question))
            # With several flagged choices the last one wins.
            correct_answer = flagged[-1]
        else:
            correct_answer = answer

        return self.qae_tokenizer(
            text=question,
            text_pair=correct_answer,
            padding='max_length',  # replaces the deprecated pad_to_max_length=True
            max_length=self.SEQ_LENGTH,
            truncation=True,
            return_tensors="pt"
        )

    def _evaluate_qa(self, encoded_qa_pair):
        """Return element ``[0][0][1]`` of the model output for one encoded pair.

        NOTE(review): presumably the logit of the "valid pair" class for the
        single item in the batch -- confirm against the model card.
        """
        output = self.qae_model(**encoded_qa_pair)
        return output[0][0][1]
312
+
313
+
314
def print_qa(qa_list, show_answers=True):
    """Print numbered questions and, optionally, their answers.

    Args:
        qa_list: Sequence of dicts with 'question' and 'answer' keys. An
            answer is either a plain value (full-sentence answer) or a list
            of ``{'answer': ..., 'correct': ...}`` multiple-choice dicts.
        show_answers: For full-sentence answers, whether the answer is
            printed at all. For multiple-choice answers the choices are
            always listed and this flag only controls the '(correct)'
            marker.
    """
    for number, qa in enumerate(qa_list, start=1):
        # Wider indent once the question number has two digits.
        space = ' ' * (3 if number < 10 else 4)

        print('{}) Q: {}'.format(number, qa['question']))

        answer = qa['answer']

        if isinstance(answer, list):
            # Multiple choice: the first line carries the 'A:' label, later
            # lines are indented by the width of 'A: ' so the numbers align.
            for position, choice in enumerate(answer, start=1):
                if position == 1:
                    prefix = '{}A: 1.'.format(space)
                else:
                    prefix = '{}{}.'.format(space + ' ' * len('A: '), position)

                if show_answers:
                    print(prefix,
                          choice['answer'],
                          '(correct)' if choice['correct'] else '')
                else:
                    print(prefix, choice['answer'])
            print('')

        elif show_answers:
            # Full-sentence answer.
            print('{}A:'.format(space), answer, '\n')