potsawee commited on
Commit
e00e573
1 Parent(s): fa3d2c7

Add application file

Browse files
Files changed (3) hide show
  1. app.py +39 -0
  2. question_generation.py +97 -0
  3. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import random
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+ from question_generation import question_generation_sampling
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+
8
+ g1_tokenizer = AutoTokenizer.from_pretrained("potsawee/t5-large-generation-squad-QuestionAnswer")
9
+ g1_model = AutoModelForSeq2SeqLM.from_pretrained("potsawee/t5-large-generation-squad-QuestionAnswer")
10
+ g2_tokenizer = AutoTokenizer.from_pretrained("potsawee/t5-large-generation-race-Distractor")
11
+ g2_model = AutoModelForSeq2SeqLM.from_pretrained("potsawee/t5-large-generation-race-Distractor")
12
+ g1_model.eval()
13
+ g2_model.eval()
14
+ g1_model.to(device)
15
+ g2_model.to(device)
16
+
17
+
18
+ def generate_multiple_choice_question(
19
+ context
20
+ ):
21
+ num_questions = 1
22
+ question_item = question_generation_sampling(
23
+ g1_model, g1_tokenizer,
24
+ g2_model, g2_tokenizer,
25
+ context, num_questions, device
26
+ )[0]
27
+ question = question_item['question']
28
+ options = question_item['options']
29
+ options[0] = f"{options[0]} [ANSWER]"
30
+ random.shuffle(options)
31
+ output_string = f"Question: {question}\n[A] {options[0]}\n[B] {options[1]}\n[C] {options[2]}\n[D] {options[3]}"
32
+ return output_string
33
+
34
+ demo = gr.Interface(
35
+ fn=generate_multiple_choice_question,
36
+ inputs=gr.Textbox(lines=5, placeholder="Context Here..."),
37
+ outputs=gr.Textbox(lines=5, placeholder="Question: ...\n[A] ...\n[B] ...\n[C] ...\n[D] ..."),
38
+ )
39
+ demo.launch()
question_generation.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import re
3
+
4
+ @torch.no_grad()
5
+ def question_generation_sampling(
6
+ g1_model,
7
+ g1_tokenizer,
8
+ g2_model,
9
+ g2_tokenizer,
10
+ context,
11
+ num_questions,
12
+ device,
13
+ ):
14
+ qa_input_ids = prepare_qa_input(
15
+ g1_tokenizer,
16
+ context=context,
17
+ device=device,
18
+ )
19
+ max_repeated_sampling = int(num_questions * 1.5) # sometimes generated question+answer is invalid
20
+ num_valid_questions = 0
21
+ questions = []
22
+ for q_ in range(max_repeated_sampling):
23
+ # Stage G.1: question+answer generation
24
+ outputs = g1_model.generate(
25
+ qa_input_ids,
26
+ max_new_tokens=128,
27
+ do_sample=True,
28
+ )
29
+ question_answer = g1_tokenizer.decode(outputs[0], skip_special_tokens=False)
30
+ question_answer = question_answer.replace(g1_tokenizer.pad_token, "").replace(g1_tokenizer.eos_token, "")
31
+ question_answer_split = question_answer.split(g1_tokenizer.sep_token)
32
+ if len(question_answer_split) == 2:
33
+ # valid Question + Annswer output
34
+ num_valid_questions += 1
35
+ else:
36
+ continue
37
+ question = question_answer_split[0].strip()
38
+ answer = question_answer_split[1].strip()
39
+
40
+ # Stage G.2: Distractor Generation
41
+ distractor_input_ids = prepare_distractor_input(
42
+ g2_tokenizer,
43
+ context = context,
44
+ question = question,
45
+ answer = answer,
46
+ device = device,
47
+ separator = g2_tokenizer.sep_token,
48
+ )
49
+ outputs = g2_model.generate(
50
+ distractor_input_ids,
51
+ max_new_tokens=128,
52
+ do_sample=True,
53
+ )
54
+ distractors = g2_tokenizer.decode(outputs[0], skip_special_tokens=False)
55
+ distractors = distractors.replace(g2_tokenizer.pad_token, "").replace(g2_tokenizer.eos_token, "")
56
+ distractors = re.sub("<extra\S+>", g2_tokenizer.sep_token, distractors)
57
+ distractors = [y.strip() for y in distractors.split(g2_tokenizer.sep_token)]
58
+ options = [answer] + distractors
59
+
60
+ while len(options) < 4:
61
+ options.append(options[-1])
62
+
63
+ question_item = {
64
+ 'question': question,
65
+ 'options': options,
66
+ }
67
+ questions.append(question_item)
68
+ if num_valid_questions == num_questions:
69
+ break
70
+ return questions
71
+
72
+
73
+ def prepare_qa_input(t5_tokenizer, context, device):
74
+ """
75
+ input: context
76
+ output: question <sep> answer
77
+ """
78
+ encoding = t5_tokenizer(
79
+ [context],
80
+ return_tensors="pt",
81
+ )
82
+ input_ids = encoding.input_ids.to(device)
83
+ return input_ids
84
+
85
+
86
+ def prepare_distractor_input(t5_tokenizer, context, question, answer, device, separator='<sep>'):
87
+ """
88
+ input: question <sep> answer <sep> article
89
+ output: distractor1 <sep> distractor2 <sep> distractor3
90
+ """
91
+ input_text = question + ' ' + separator + ' ' + answer + ' ' + separator + ' ' + context
92
+ encoding = t5_tokenizer(
93
+ [input_text],
94
+ return_tensors="pt",
95
+ )
96
+ input_ids = encoding.input_ids.to(device)
97
+ return input_ids
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch>=1.10
2
+ transformers>=4.11.3