# AVeriTeC: src/reranking/rerank_questions.py
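"""Rerank generated question-answer (QA) pairs for each claim with a trained BERT
dual encoder and keep the top-n pairs as evidence.

The input is a JSON Lines file with one claim and its top-k candidate QA pairs per
line; the output is a JSON Lines file with the claim and its top-n reranked pairs.

One possible invocation, using the argparse defaults and assuming the repository
root is on PYTHONPATH:

    python -m src.reranking.rerank_questions \
        --top_k_qa_file data_store/dev_top_k_qa.json \
        --output_file data_store/dev_top_3_rerank_qa.json
"""
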
import argparse
import json

import torch
import tqdm
from transformers import BertTokenizer, BertForSequenceClassification

from src.models.DualEncoderModule import DualEncoderModule


def triple_to_string(x):
    # Join (claim, question, answer) into one "claim </s> question </s> answer" string.
    return " </s> ".join([item.strip() for item in x])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Rerank the QA pairs and keep the top 3 QA pairs as evidence, using a pre-trained BERT model."
    )
    parser.add_argument(
        "-i",
        "--top_k_qa_file",
        default="data_store/dev_top_k_qa.json",
        help="Json lines file with each claim and its top k generated question-answer pairs.",
    )
    parser.add_argument(
        "-o",
        "--output_file",
        default="data_store/dev_top_3_rerank_qa.json",
        help="Json lines file with the top 3 reranked question-answer pairs per claim.",
    )
    parser.add_argument(
        "-ckpt",
        "--best_checkpoint",
        type=str,
        default="pretrained_models/bert_dual_encoder.ckpt",
        help="Path to the trained BERT dual encoder checkpoint.",
    )
    parser.add_argument(
        "--top_n",
        type=int,
        default=3,
        help="Number of top question-answer pairs to keep as evidence.",
    )
    args = parser.parse_args()

    examples = []
    with open(args.top_k_qa_file) as f:
        for line in f:
            examples.append(json.loads(line))
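
    # The reranker wraps a BERT sequence classifier in the DualEncoderModule;
    # the fine-tuned weights are loaded from the checkpoint below.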
    bert_model_name = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    bert_model = BertForSequenceClassification.from_pretrained(
        bert_model_name, num_labels=2, problem_type="single_label_classification"
    )

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    trained_model = DualEncoderModule.load_from_checkpoint(
        args.best_checkpoint, tokenizer=tokenizer, model=bert_model
    ).to(device)
    trained_model.eval()  # inference only: disable dropout for deterministic scores
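
    # Score each claim's candidate QA pairs, keep the top-n highest-scoring ones
    # as evidence, and write one JSON object per claim to the output file.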
    with open(args.output_file, "w", encoding="utf-8") as output_file:
        for example in tqdm.tqdm(examples):
            strs_to_score = []
            values = []

            bm25_qau = example["bm25_qau"] if "bm25_qau" in example else []
            claim = example["claim"]

            for question, answer, url in bm25_qau:
                str_to_score = triple_to_string([claim, question, answer])
                strs_to_score.append(str_to_score)
                values.append([question, answer, url])
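
            # Batch-tokenise all candidate strings for this claim and score them with
            # the trained model; column 1 of the softmax is the relevance probability.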
            if len(bm25_qau) > 0:
                encoded_dict = tokenizer(
                    strs_to_score,
                    max_length=512,
                    padding="longest",
                    truncation=True,
                    return_tensors="pt",
                ).to(device)

                input_ids = encoded_dict["input_ids"]
                attention_masks = encoded_dict["attention_mask"]

                with torch.no_grad():  # no gradients needed at inference time
                    scores = torch.softmax(
                        trained_model(input_ids, attention_mask=attention_masks).logits,
                        dim=-1,
                    )[:, 1]
                top_n = torch.argsort(scores, descending=True)[: args.top_n]
                evidence = [
                    {
                        "question": values[i][0],
                        "answer": values[i][1],
                        "url": values[i][2],
                    }
                    for i in top_n
                ]
            else:
                evidence = []
            json_data = {
                "claim_id": example["claim_id"],
                "claim": claim,
                "evidence": evidence,
            }
            output_file.write(json.dumps(json_data, ensure_ascii=False) + "\n")
            output_file.flush()