File size: 3,113 Bytes
06f0d67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import re

def _decode_without_padding(tokenizer, token_ids):
    """Decode one generated sequence, keeping special tokens but removing pad/EOS markers."""
    # skip_special_tokens=False keeps the separator token in the text so it can
    # be split on afterwards; pad and EOS are then stripped by hand. A token
    # that is None (undefined for this tokenizer) is skipped instead of making
    # str.replace raise TypeError.
    text = tokenizer.decode(token_ids, skip_special_tokens=False)
    for token in (tokenizer.pad_token, tokenizer.eos_token):
        if token:
            text = text.replace(token, "")
    return text


def _split_distractors(tokenizer, text):
    """Split decoded distractor-generator output into stripped distractor strings."""
    # Tokens of the form <extra...> (presumably T5 sentinel tokens such as
    # <extra_id_0>) are treated as separators as well. The pattern stops at the
    # first '>' so text between two adjacent sentinels is not swallowed, and it
    # is a raw string so '\s' is not parsed as a string escape.
    text = re.sub(r"<extra[^\s>]*>", tokenizer.sep_token, text)
    return [piece.strip() for piece in text.split(tokenizer.sep_token)]


@torch.no_grad()
def question_generation_sampling(
    g1_model,
    g1_tokenizer,
    g2_model,
    g2_tokenizer,
    context,
    num_questions,
    device,
    max_new_tokens=128,
    oversampling_factor=1.5,
):
    """Sample multiple-choice questions for ``context`` with a two-stage pipeline.

    Stage G.1 (``g1_model``) samples "question <sep> answer" text from the
    context. Stage G.2 (``g2_model``) samples distractors for each valid
    question/answer pair. A G.1 sample is discarded when it does not split into
    exactly two parts on the separator token. To compensate, sampling is
    attempted up to ``int(num_questions * oversampling_factor)`` times.

    Args:
        g1_model: question+answer generator (must provide ``generate``).
        g1_tokenizer: tokenizer for ``g1_model``; its ``sep_token`` separates
            the question from the answer.
        g2_model: distractor generator (must provide ``generate``).
        g2_tokenizer: tokenizer for ``g2_model``; its ``sep_token`` separates
            distractors and is used as the input field separator.
        context: source passage the questions are generated from.
        num_questions: number of questions wanted. Fewer may be returned if
            too many G.1 samples are invalid.
        device: device the tokenized inputs are moved to.
        max_new_tokens: generation budget for every ``generate`` call.
        oversampling_factor: multiplier on ``num_questions`` that bounds the
            total number of G.1 sampling attempts.

    Returns:
        A list of at most ``num_questions`` dicts with keys ``'question'`` and
        ``'options'``. ``options[0]`` is the answer, followed by the generated
        distractors. If there are fewer than 4 options, the list is padded by
        repeating its last entry.
    """
    qa_input_ids = prepare_qa_input(g1_tokenizer, context=context, device=device)
    max_attempts = int(num_questions * oversampling_factor)
    questions = []
    for _ in range(max_attempts):
        # Stage G.1: question+answer generation
        outputs = g1_model.generate(
            qa_input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
        )
        question_answer = _decode_without_padding(g1_tokenizer, outputs[0])
        parts = question_answer.split(g1_tokenizer.sep_token)
        if len(parts) != 2:
            # Not exactly one question and one answer: discard and resample.
            continue
        question = parts[0].strip()
        answer = parts[1].strip()

        # Stage G.2: distractor generation
        distractor_input_ids = prepare_distractor_input(
            g2_tokenizer,
            context=context,
            question=question,
            answer=answer,
            device=device,
            separator=g2_tokenizer.sep_token,
        )
        outputs = g2_model.generate(
            distractor_input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
        )
        distractors = _split_distractors(
            g2_tokenizer, _decode_without_padding(g2_tokenizer, outputs[0])
        )
        options = [answer] + distractors

        # Guarantee at least four options by repeating the last one.
        while len(options) < 4:
            options.append(options[-1])

        questions.append({
            'question': question,
            'options': options,
        })
        if len(questions) == num_questions:
            break
    return questions


def prepare_qa_input(t5_tokenizer, context, device):
    """Tokenize a single context for the question+answer generator.

    Model input:     context
    Expected output: question <sep> answer

    The context is tokenized as a batch of one with ``return_tensors="pt"``.
    The resulting ``input_ids`` are returned after being moved to ``device``.
    """
    batch = t5_tokenizer([context], return_tensors="pt")
    return batch.input_ids.to(device)


def prepare_distractor_input(t5_tokenizer, context, question, answer, device, separator='<sep>'):
    """Tokenize a (question, answer, context) triple for the distractor generator.

    The fields are joined with single spaces around ``separator``:

        model input:     question <sep> answer <sep> article
        expected output: distractor1 <sep> distractor2 <sep> distractor3

    The joined text is tokenized as a batch of one with ``return_tensors="pt"``.
    The resulting ``input_ids`` are returned after being moved to ``device``.
    """
    fields = (question, separator, answer, separator, context)
    batch = t5_tokenizer([" ".join(fields)], return_tensors="pt")
    return batch.input_ids.to(device)