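"""Beam-style n-best answer extraction for a LUKE question-answering model.

Runs botcon/LUKE_squadshift_finetuned_large over a SQuAD-format dataset and
returns the top luke_beam_size answer candidates (with logit scores) per
example.
"""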
import collections

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, default_data_collator

luke_beam_size = 5      # number of candidate answers kept per example
n_best = 20             # top-k start/end logits paired into candidate spans
max_length = 512        # maximum tokenized length of question + context
stride = 128            # token overlap between consecutive context windows
batch_size = 8          # inference batch size
max_answer_length = 30  # longest admissible answer span, in tokens

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
luke_model = AutoModelForQuestionAnswering.from_pretrained("botcon/LUKE_squadshift_finetuned_large").to(device)
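# The RoBERTa tokenizer is used here on the assumption that this LUKE
# checkpoint shares RoBERTa's vocabulary (LUKE is built on a RoBERTa backbone).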
luke_tokenizer = AutoTokenizer.from_pretrained("roberta-base")

def compute_beam(start_logits, end_logits, features, examples):
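    """Collect the top luke_beam_size answer candidates for each example.

    For every feature (context window) of an example, the n_best highest
    start and end logits are paired into candidate spans; spans that fall
    outside the context, run backwards, or exceed max_answer_length are
    discarded, and the survivors are ranked by summed start + end logit.
    """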
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Keep the luke_beam_size highest-scoring candidates, padding with
        # empty answers so every example yields exactly luke_beam_size entries.
        best_answers = sorted(answers, key=lambda x: x["logit_score"], reverse=True)
        best_ans = [a["text"] for a in best_answers[:luke_beam_size]]
        best_logits = [a["logit_score"] for a in best_answers[:luke_beam_size]]
        while len(best_ans) < luke_beam_size:
            best_ans.append("")
            best_logits.append(float("-inf"))  # padding can never outrank a real answer
        predicted_answers.append({"id": example_id, "prediction_text": best_ans, "logits": best_logits})

    return predicted_answers

def preprocess_validation_examples(examples):
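    """Tokenize question/context pairs into overlapping features.

    Contexts longer than max_length are split into several windows with a
    stride-token overlap. Offset-mapping entries that do not belong to the
    context (sequence id != 1) are set to None so compute_beam can skip
    spans that start or end inside the question.
    """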
    questions = [q.strip() for q in examples["question"]]
    inputs = luke_tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs

def generate(dataset):
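    """Run the model over a SQuAD-format dataset and return n-best predictions.

    `dataset` is expected to be a datasets.Dataset with "id", "context", and
    "question" columns.
    """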
    luke_model.eval()
    with torch.no_grad():
        preprocessed = dataset.map(
            preprocess_validation_examples,
            batched=True,
            remove_columns=dataset.column_names,
        )
        eval_set_for_model = preprocessed.remove_columns(["example_id", "offset_mapping"])
        eval_set_for_model.set_format("torch")

        # Run inference in batches of batch_size rather than pushing the whole
        # set through one forward pass, which exhausts memory on large sets.
        eval_loader = DataLoader(eval_set_for_model, batch_size=batch_size, collate_fn=default_data_collator)
        start_logits, end_logits = [], []
        for batch in eval_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = luke_model(**batch)
            start_logits.append(outputs.start_logits.cpu().numpy())
            end_logits.append(outputs.end_logits.cpu().numpy())
        start_logits = np.concatenate(start_logits)
        end_logits = np.concatenate(end_logits)

        return compute_beam(start_logits, end_logits, preprocessed, dataset)
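

# Minimal usage sketch. The dataset name and split below are assumptions for
# illustration (SQuADShifts "new_wiki" via the `datasets` library); any
# datasets.Dataset with "id", "context", and "question" columns works.
if __name__ == "__main__":
    from datasets import load_dataset

    squad = load_dataset("squadshifts", "new_wiki", split="test[:32]")
    predictions = generate(squad)
    # Each prediction carries the top-luke_beam_size answer texts and scores.
    print(predictions[0]["prediction_text"])
    print(predictions[0]["logits"])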