import argparse
import json

import torch
import tqdm
from transformers import BertTokenizer, BertForSequenceClassification

from src.models.DualEncoderModule import DualEncoderModule


def triple_to_string(x):
    return " ".join([item.strip() for item in x])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Rerank the QA pairs and keep the top 3 QA pairs as evidence using a pre-trained BERT model."
    )
    parser.add_argument(
        "-i",
        "--top_k_qa_file",
        default="data_store/dev_top_k_qa.json",
        help="JSON lines file with the claim and top-k generated question-answer pairs.",
    )
    parser.add_argument(
        "-o",
        "--output_file",
        default="data_store/dev_top_3_rerank_qa.json",
        help="JSON lines file with the top-3 reranked question-answer pairs.",
    )
    parser.add_argument(
        "-ckpt",
        "--best_checkpoint",
        type=str,
        default="pretrained_models/bert_dual_encoder.ckpt",
    )
    parser.add_argument(
        "--top_n",
        type=int,
        default=3,
        help="Number of top question-answer pairs to keep as evidence.",
    )
    args = parser.parse_args()

    # Input is JSON lines: one example per line.
    examples = []
    with open(args.top_k_qa_file) as f:
        for line in f:
            examples.append(json.loads(line))

    bert_model_name = "bert-base-uncased"

    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    bert_model = BertForSequenceClassification.from_pretrained(
        bert_model_name, num_labels=2, problem_type="single_label_classification"
    )

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    trained_model = DualEncoderModule.load_from_checkpoint(
        args.best_checkpoint, tokenizer=tokenizer, model=bert_model
    ).to(device)
    trained_model.eval()  # disable dropout for deterministic scoring

    with open(args.output_file, "w", encoding="utf-8") as output_file:
        for example in tqdm.tqdm(examples):
            strs_to_score = []
            values = []

            bm25_qau = example.get("bm25_qau", [])
            claim = example["claim"]

            # Build one "claim question answer" string per retrieved QA pair.
            for question, answer, url in bm25_qau:
                str_to_score = triple_to_string([claim, question, answer])
                strs_to_score.append(str_to_score)
                values.append([question, answer, url])

            if len(bm25_qau) > 0:
                encoded_dict = tokenizer(
                    strs_to_score,
                    max_length=512,
                    padding="longest",
                    truncation=True,
                    return_tensors="pt",
                ).to(device)

                input_ids = encoded_dict["input_ids"]
                attention_masks = encoded_dict["attention_mask"]

                # Probability of the positive class serves as the relevance score.
                with torch.no_grad():
                    scores = torch.softmax(
                        trained_model(input_ids, attention_mask=attention_masks).logits,
                        dim=-1,
                    )[:, 1]

                # Keep the top_n highest-scoring QA pairs as evidence.
                top_n = torch.argsort(scores, descending=True)[: args.top_n]
                evidence = [
                    {
                        "question": values[i][0],
                        "answer": values[i][1],
                        "url": values[i][2],
                    }
                    for i in top_n
                ]
            else:
                evidence = []

            json_data = {
                "claim_id": example["claim_id"],
                "claim": claim,
                "evidence": evidence,
            }
            output_file.write(json.dumps(json_data, ensure_ascii=False) + "\n")
            output_file.flush()
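
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the script). The script name below
# is a placeholder; the paths are the argparse defaults above, and the record
# shapes mirror the fields this script reads and writes:
#
#   python rerank_qa.py \
#       -i data_store/dev_top_k_qa.json \
#       -o data_store/dev_top_3_rerank_qa.json \
#       -ckpt pretrained_models/bert_dual_encoder.ckpt \
#       --top_n 3
#
# Each input line is a JSON object like:
#   {"claim_id": 0, "claim": "...",
#    "bm25_qau": [["question", "answer", "https://example.com"], ...]}
#
# Each output line keeps only the top_n reranked pairs:
#   {"claim_id": 0, "claim": "...",
#    "evidence": [{"question": "...", "answer": "...", "url": "..."}, ...]}
# ---------------------------------------------------------------------------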