File size: 3,543 Bytes
eaaaf3d
 
 
 
 
ce6cd35
eaaaf3d
 
 
 
 
 
 
 
 
 
 
 
 
2b4f5ff
eaaaf3d
 
 
 
 
4cac25e
eaaaf3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cac25e
eaaaf3d
4cac25e
 
eaaaf3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import argparse
import json
import torch
import tqdm
from transformers import BertTokenizer, BertForSequenceClassification
from src.models.DualEncoderModule import DualEncoderModule


def triple_to_string(x):
    return " </s> ".join([item.strip() for item in x])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Rerank the QA paris and keep top 3 QA paris as evidence using a pre-trained BERT model."
    )
    parser.add_argument(
        "-i",
        "--top_k_qa_file",
        default="data_store/dev_top_k_qa.json",
        help="Json file with claim and top k generated question-answer pairs.",
    )
    parser.add_argument(
        "-o",
        "--output_file",
        default="data_store/dev_top_3_rerank_qa.json",
        help="Json file with the top3 reranked questions.",
    )
    parser.add_argument(
        "-ckpt",
        "--best_checkpoint",
        type=str,
        default="pretrained_models/bert_dual_encoder.ckpt",
    )
    parser.add_argument(
        "--top_n",
        type=int,
        default=3,
        help="top_n question answer pairs as evidence to keep.",
    )
    args = parser.parse_args()

    examples = []
    with open(args.top_k_qa_file) as f:
        for line in f:
            examples.append(json.loads(line))

    bert_model_name = "bert-base-uncased"

    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    bert_model = BertForSequenceClassification.from_pretrained(
        bert_model_name, num_labels=2, problem_type="single_label_classification"
    )
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    trained_model = DualEncoderModule.load_from_checkpoint(
        args.best_checkpoint, tokenizer=tokenizer, model=bert_model
    ).to(device)

    with open(args.output_file, "w", encoding="utf-8") as output_file:
        for example in tqdm.tqdm(examples):
            strs_to_score = []
            values = []

            bm25_qau = example["bm25_qau"] if "bm25_qau" in example else []
            claim = example["claim"]

            for question, answer, url in bm25_qau:
                str_to_score = triple_to_string([claim, question, answer])

                strs_to_score.append(str_to_score)
                values.append([question, answer, url])

            if len(bm25_qau) > 0:
                encoded_dict = tokenizer(
                    strs_to_score,
                    max_length=512,
                    padding="longest",
                    truncation=True,
                    return_tensors="pt",
                ).to(device)

                input_ids = encoded_dict["input_ids"]
                attention_masks = encoded_dict["attention_mask"]

                scores = torch.softmax(
                    trained_model(input_ids, attention_mask=attention_masks).logits,
                    axis=-1,
                )[:, 1]

                top_n = torch.argsort(scores, descending=True)[: args.top_n]
                evidence = [
                    {
                        "question": values[i][0],
                        "answer": values[i][1],
                        "url": values[i][2],
                    }
                    for i in top_n
                ]
            else:
                evidence = []

            json_data = {
                "claim_id": example["claim_id"],
                "claim": claim,
                "evidence": evidence,
            }
            output_file.write(json.dumps(json_data, ensure_ascii=False) + "\n")
            output_file.flush()