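"""Veracity prediction: given a claim and its top question-answer pairs as
evidence, predict a veracity label with a fine-tuned BERT classifier.

Reads a JSON file of claims with QA evidence, scores each
(claim, question, answer) triple with the classifier, aggregates the
per-evidence predictions, and writes the labelled claims to an output
JSON file.
"""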
import argparse
import json
import tqdm
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from data_loaders.SequenceClassificationDataLoader import (
SequenceClassificationDataLoader,
)
from models.SequenceClassificationModule import SequenceClassificationModule
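# Output classes of the classifier; the list order corresponds to the
# model's output indices.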
LABEL = [
    "Supported",
    "Refuted",
    "Not Enough Evidence",
    "Conflicting Evidence/Cherrypicking",
]
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Given a claim and its top question-answer pairs as evidence, use a fine-tuned BERT model to predict the veracity label."
    )
    parser.add_argument(
        "-i",
        "--claim_with_evidence_file",
        default="data/dev_top3_questions.json",
        help="Json file with claim and top question-answer pairs as evidence.",
    )
    parser.add_argument(
        "-o",
        "--output_file",
        default="data_store/dev_veracity.json",
        help="Json file with the veracity predictions.",
    )
    parser.add_argument(
        "-ckpt",
        "--best_checkpoint",
        type=str,
        default="pretrained_models/bert_veracity.ckpt",
        help="Checkpoint of the fine-tuned BERT veracity classifier.",
    )
    args = parser.parse_args()
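    # Each input example is expected (as consumed below) to provide a claim id,
    # the claim text, and a list of question-answer evidence, roughly:
    #   {"claim_id": ..., "claim": "...", "evidence": [{"question": "...", "answer": "..."}, ...]}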
    with open(args.claim_with_evidence_file) as f:
        examples = json.load(f)

    # Build the underlying Hugging Face model, then load the fine-tuned
    # weights from the Lightning checkpoint.
    bert_model_name = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    bert_model = BertForSequenceClassification.from_pretrained(
        bert_model_name, num_labels=4, problem_type="single_label_classification"
    )

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    trained_model = SequenceClassificationModule.load_from_checkpoint(
        args.best_checkpoint, tokenizer=tokenizer, model=bert_model
    ).to(device)
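    # The data loader is only used for its string-building and tokenization
    # helpers here, so its data_file argument is just a placeholder.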
    data_loader = SequenceClassificationDataLoader(
        tokenizer=tokenizer,
        data_file="this_is_discontinued",
        batch_size=32,
        add_extra_nee=False,
    )

    predictions = []
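    # For each claim: classify every (claim, question, answer) triple, then
    # aggregate the per-evidence predictions into a single veracity label.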
    for example in tqdm.tqdm(examples):
        example_strings = []
        for evidence in example["evidence"]:
            example_strings.append(
                data_loader.quadruple_to_string(
                    example["claim"], evidence["question"], evidence["answer"], ""
                )
            )

        # If we found no evidence, e.g. because Google returned 0 pages, just
        # output NEI; without appending here the claim would be missing from
        # the output file entirely.
        if len(example_strings) == 0:
            predictions.append(
                {
                    "claim_id": example["claim_id"],
                    "claim": example["claim"],
                    "evidence": example["evidence"],
                    "label": "Not Enough Evidence",
                }
            )
            continue

        tokenized_strings, attention_mask = data_loader.tokenize_strings(example_strings)
        # Batch-classify the evidence strings; move the tensors to the model's
        # device and skip gradient tracking at inference time.
        with torch.no_grad():
            example_support = torch.argmax(
                trained_model(
                    tokenized_strings.to(device), attention_mask=attention_mask.to(device)
                ).logits,
                dim=1,
            )
        # Aggregate the per-evidence verdicts: any NEI/conflicting prediction
        # yields NEI, unanimous support or refutation keeps that label, and a
        # mix of support and refutation becomes conflicting evidence.
        has_unanswerable = False
        has_true = False
        has_false = False
        for v in example_support:
            if v == 0:
                has_true = True
            if v == 1:
                has_false = True
            if v in (2, 3):
                # TODO another hack -- we can't have different labels for train
                # and test, so both classes are folded into "unanswerable" here.
                has_unanswerable = True

        if has_unanswerable:
            answer = 2
        elif has_true and not has_false:
            answer = 0
        elif not has_true and has_false:
            answer = 1
        else:  # Both supporting and refuting evidence present.
            answer = 3

        json_data = {
            "claim_id": example["claim_id"],
            "claim": example["claim"],
            "evidence": example["evidence"],
            "label": LABEL[answer],
        }
        predictions.append(json_data)
    with open(args.output_file, "w", encoding="utf-8") as output_file:
        json.dump(predictions, output_file, ensure_ascii=False, indent=4)