# AVeriTeC: src/reranking/rerank_questions.py
import argparse
import json
import torch
import tqdm
from transformers import BertTokenizer, BertForSequenceClassification
from src.models.DualEncoderModule import DualEncoderModule

def triple_to_string(x):
    # Join a (claim, question, answer) triple with BERT's </s> separator.
    return " </s> ".join([item.strip() for item in x])

if __name__ == "__main__":
parser = argparse.ArgumentParser(
        description="Rerank the generated QA pairs and keep the top n pairs as evidence, using a trained BERT dual encoder."
)
parser.add_argument(
"-i",
"--top_k_qa_file",
default="data_store/dev_top_k_qa.json",
help="Json file with claim and top k generated question-answer pairs.",
)
parser.add_argument(
"-o",
"--output_file",
default="data_store/dev_top_3_rerank_qa.json",
        help="Json file with the top n reranked question-answer pairs.",
)
    parser.add_argument(
        "-ckpt",
        "--best_checkpoint",
        type=str,
        default="pretrained_models/bert_dual_encoder.ckpt",
        help="Path to the trained dual-encoder checkpoint.",
    )
parser.add_argument(
"--top_n",
type=int,
default=3,
        help="Number of top question-answer pairs to keep as evidence.",
)
args = parser.parse_args()
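
    # The input file is JSON Lines: one object per line, each holding a claim
    # and its top-k generated question-answer pairs.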
examples = []
with open(args.top_k_qa_file) as f:
for line in f:
examples.append(json.loads(line))
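
    # Build the BERT backbone, then load the trained dual-encoder reranker
    # around it from the Lightning checkpoint.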
bert_model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(
bert_model_name, num_labels=2, problem_type="single_label_classification"
)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    trained_model = DualEncoderModule.load_from_checkpoint(
        args.best_checkpoint, tokenizer=tokenizer, model=bert_model
    ).to(device)
    trained_model.eval()  # inference only; disables dropout
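
    # Score every candidate QA pair against its claim and keep the top n
    # as evidence, writing one JSON object per claim (JSON Lines output).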
with open(args.output_file, "w", encoding="utf-8") as output_file:
for example in tqdm.tqdm(examples):
strs_to_score = []
values = []
            bm25_qau = example.get("bm25_qau", [])
claim = example["claim"]
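            # Build "claim </s> question </s> answer" strings for scoring.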
for question, answer, url in bm25_qau:
str_to_score = triple_to_string([claim, question, answer])
strs_to_score.append(str_to_score)
values.append([question, answer, url])
if len(bm25_qau) > 0:
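                # Batch-tokenize all candidate triples for this claim.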
encoded_dict = tokenizer(
strs_to_score,
max_length=512,
padding="longest",
truncation=True,
return_tensors="pt",
).to(device)
                input_ids = encoded_dict["input_ids"]
                attention_masks = encoded_dict["attention_mask"]
                # Softmax over the two classes; column 1 is the probability
                # that a QA pair is relevant evidence for the claim.
                with torch.no_grad():
                    scores = torch.softmax(
                        trained_model(
                            input_ids, attention_mask=attention_masks
                        ).logits,
                        dim=-1,
                    )[:, 1]
                top_n_indices = torch.argsort(scores, descending=True)[: args.top_n]
                evidence = [
                    {
                        "question": values[i][0],
                        "answer": values[i][1],
                        "url": values[i][2],
                    }
                    for i in top_n_indices
                ]
else:
evidence = []
json_data = {
"claim_id": example["claim_id"],
"claim": claim,
"evidence": evidence,
}
output_file.write(json.dumps(json_data, ensure_ascii=False) + "\n")
output_file.flush()
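
# Example invocation (a sketch; paths are the argument defaults above and
# assume the repository's data_store/ and pretrained_models/ layout):
#   python -m src.reranking.rerank_questions \
#       --top_k_qa_file data_store/dev_top_k_qa.json \
#       --output_file data_store/dev_top_3_rerank_qa.json \
#       --top_n 3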