Chenxi Whitehouse commited on
Commit
eaaaf3d
·
1 Parent(s): 4400462

add src files

Browse files
README.md CHANGED
@@ -47,7 +47,7 @@ The training and dev dataset can be found under [data](https://huggingface.co/ch
47
 
48
  ## Reproduce the baseline
49
 
50
- Below are the steps to reproduce the baseline results. The main difference from the reported results in the paper is that, instead of requiring direct access to the paid Google Search API, we provide such search results for up to 1000 URLs per claim using different queries, and the scraped text as a knowledge store for retrieval for each claim. This is aimed at reducing the overhead cost of participating in the Shared Task.
51
 
52
 
53
  ### 0. Set up environment
@@ -93,28 +93,35 @@ python -m src.reranking.bm25_sentences
93
  ```
94
 
95
  ### 3. Generate questions-answer pair for the top sentences
96
- We use [BLOOM](https://huggingface.co/bigscience/bloom-7b1) to generate QA paris for each of the top 100 sentence, providing 10 closest claim-QA-pairs from the training set as in-context examples. See [question_generation_top_sentences.py](https://huggingface.co/chenxwh/AVeriTeC/blob/main/src/reranking/question_generation_top_sentences.py) for more argument options. We provide the output file for this step on the dev set [here](https://huggingface.co/chenxwh/AVeriTeC/blob/main/data_store/ddev_top_k_qa.json).
97
  ```bash
98
  python -m src.reranking.question_generation_top_sentences
99
  ```
100
 
101
  ### 4. Rerank the QA pairs
102
- Using a pre-trained BERT model [bert_dual_encoder.ckpt](https://huggingface.co/chenxwh/AVeriTeC/blob/main/pretrained_models/bert_dual_encoder.ckpt), we rerank the QA paris and keep top 3 QA paris as evidence. We provide the output file for this step on the dev set [here]().
103
  ```bash
 
104
  ```
105
 
106
 
107
  ### 5. Veracity prediction
108
- Finally, given a claim and its 3 QA pairs as evidence, we use another pre-trained BERT model [bert_veracity.ckpt](https://huggingface.co/chenxwh/AVeriTeC/blob/main/pretrained_models/bert_veracity.ckpt) to predict the veracity label. The pre-trained model is provided . We provide the prediction file for this step on the dev set [here]().
109
  ```bash
 
110
  ```
111
  The results will be presented as follows:
112
- ```bash
 
113
  ```
114
 
115
- We recommend using 0.25 as cut-off score for evaluating the relevance of the evidence. The result for dev and the test set below.
116
 
 
117
 
 
 
 
 
118
 
119
  ## Citation
120
  If you find AVeriTeC useful for your research and applications, please cite us using this BibTeX:
 
47
 
48
  ## Reproduce the baseline
49
 
50
+ Below are the steps to reproduce the baseline results. The main difference from the reported results in the paper is that, instead of requiring direct access to the paid Google Search API, we provide such search results for up to 1000 URLs per claim using different queries, and the scraped text as a knowledge store for retrieval for each claim. This is aimed at reducing the overhead cost of participating in the Shared Task. Another difference is that we also added text scraped from pdf URLs to the knowledge store.
51
 
52
 
53
  ### 0. Set up environment
 
93
  ```
94
 
95
  ### 3. Generate questions-answer pair for the top sentences
96
+ We use [BLOOM](https://huggingface.co/bigscience/bloom-7b1) to generate QA paris for each of the top 100 sentence, providing 10 closest claim-QA-pairs from the training set as in-context examples. See [question_generation_top_sentences.py](https://huggingface.co/chenxwh/AVeriTeC/blob/main/src/reranking/question_generation_top_sentences.py) for more argument options. We provide the output file for this step on the dev set [here](https://huggingface.co/chenxwh/AVeriTeC/blob/main/data_store/dev_top_k_qa.json).
97
  ```bash
98
  python -m src.reranking.question_generation_top_sentences
99
  ```
100
 
101
  ### 4. Rerank the QA pairs
102
+ Using a pre-trained BERT model [bert_dual_encoder.ckpt](https://huggingface.co/chenxwh/AVeriTeC/blob/main/pretrained_models/bert_dual_encoder.ckpt), we rerank the QA paris and keep top 3 QA paris as evidence. See [rerank_questions.py](https://huggingface.co/chenxwh/AVeriTeC/blob/main/src/reranking/rerank_questions.py) for more argument options. We provide the output file for this step on the dev set [here](https://huggingface.co/chenxwh/AVeriTeC/blob/main/data_store/dev_top_3_rerank_qa.json).
103
  ```bash
104
+ python -m reranking.rerank_questions
105
  ```
106
 
107
 
108
  ### 5. Veracity prediction
109
+ Finally, given a claim and its 3 QA pairs as evidence, we use another pre-trained BERT model [bert_veracity.ckpt](https://huggingface.co/chenxwh/AVeriTeC/blob/main/pretrained_models/bert_veracity.ckpt) to predict the veracity label. See [veracity_prediction.py](https://huggingface.co/chenxwh/AVeriTeC/blob/main/src/prediction/veracity_prediction.py) for more argument options. We provide the prediction file for this step on the dev set [here](https://huggingface.co/chenxwh/AVeriTeC/blob/main/data_store/dev_vericity_prediction.json).
110
  ```bash
111
+ python -m prediction.veracity_prediction
112
  ```
113
  The results will be presented as follows:
114
+
115
+ ```
116
  ```
117
 
 
118
 
119
+ We recommend using 0.25 as cut-off score for evaluating the relevance of the evidence. The result for dev and the test set below.
120
 
121
+ | Model | Split | Q only | Q + A | Veracity @ 0.2 | @ 0.25 | @ 0.3 |
122
+ |-------------------|-------|--------|-------|----------------|--------|-------|
123
+ | AVeriTeC-BLOOM-7b | dev | | | | | |
124
+ | AVeriTeC-BLOOM-7b | test | | | | | |
125
 
126
  ## Citation
127
  If you find AVeriTeC useful for your research and applications, please cite us using this BibTeX:
src/prediction/veracity_prediction.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import tqdm
4
+ import torch
5
+ from transformers import BertTokenizer, BertForSequenceClassification
6
+ from data_loaders.SequenceClassificationDataLoader import (
7
+ SequenceClassificationDataLoader,
8
+ )
9
+ from models.SequenceClassificationModule import SequenceClassificationModule
10
+
11
+
12
+ LABEL = [
13
+ "Supported",
14
+ "Refuted",
15
+ "Not Enough Evidence",
16
+ "Conflicting Evidence/Cherrypicking",
17
+ ]
18
+
19
+
20
+ if __name__ == "__main__":
21
+ parser = argparse.ArgumentParser(
22
+ description="Given a claim and its 3 QA pairs as evidence, we use another pre-trained BERT model to predict the veracity label."
23
+ )
24
+ parser.add_argument(
25
+ "-i",
26
+ "--claim_with_evidence_file",
27
+ default="data/dev_top3_questions.json",
28
+ help="Json file with claim and top question-answer pairs as evidence.",
29
+ )
30
+ parser.add_argument(
31
+ "-o",
32
+ "--output_file",
33
+ default="data_store/dev_veracity.json",
34
+ help="Json file with the veracity predictions.",
35
+ )
36
+ parser.add_argument(
37
+ "-ckpt",
38
+ "--best_checkpoint",
39
+ type=str,
40
+ default="pretrained_models/bert_veracity.ckpt",
41
+ )
42
+ args = parser.parse_args()
43
+
44
+ with open(args.claim_with_evidence_file) as f:
45
+ examples = json.load(f)
46
+
47
+ bert_model_name = "bert-base-uncased"
48
+
49
+ tokenizer = BertTokenizer.from_pretrained(bert_model_name)
50
+ bert_model = BertForSequenceClassification.from_pretrained(
51
+ bert_model_name, num_labels=4, problem_type="single_label_classification"
52
+ )
53
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
54
+ trained_model = SequenceClassificationModule.load_from_checkpoint(
55
+ args.best_checkpoint, tokenizer=tokenizer, model=bert_model
56
+ ).to(device)
57
+
58
+ dataLoader = SequenceClassificationDataLoader(
59
+ tokenizer=tokenizer,
60
+ data_file="this_is_discontinued",
61
+ batch_size=32,
62
+ add_extra_nee=False,
63
+ )
64
+
65
+ predictions = []
66
+
67
+ for example in tqdm.tqdm(examples):
68
+ example_strings = []
69
+ for evidence in example["evidence"]:
70
+ example_strings.append(
71
+ dataLoader.quadruple_to_string(
72
+ example["claim"], evidence["question"], evidence["answer"], ""
73
+ )
74
+ )
75
+
76
+ if (
77
+ len(example_strings) == 0
78
+ ): # If we found no evidence e.g. because google returned 0 pages, just output NEI.
79
+ example["label"] = "Not Enough Evidence"
80
+ continue
81
+
82
+ tokenized_strings, attention_mask = dataLoader.tokenize_strings(example_strings)
83
+ example_support = torch.argmax(
84
+ trained_model(tokenized_strings, attention_mask=attention_mask).logits,
85
+ axis=1,
86
+ )
87
+
88
+ has_unanswerable = False
89
+ has_true = False
90
+ has_false = False
91
+
92
+ for v in example_support:
93
+ if v == 0:
94
+ has_true = True
95
+ if v == 1:
96
+ has_false = True
97
+ if v in (
98
+ 2,
99
+ 3,
100
+ ): # TODO another hack -- we cant have different labels for train and test so we do this
101
+ has_unanswerable = True
102
+
103
+ if has_unanswerable:
104
+ answer = 2
105
+ elif has_true and not has_false:
106
+ answer = 0
107
+ elif not has_true and has_false:
108
+ answer = 1
109
+ else:
110
+ answer = 3
111
+
112
+ json_data = {
113
+ "claim_id": example["claim_id"],
114
+ "claim": example["claim"],
115
+ "evidence": example["evidence"],
116
+ "label": LABEL[answer],
117
+ }
118
+ predictions.append(json_data)
119
+
120
+ with open(args.output_file, "w", encoding="utf-8") as output_file:
121
+ json.dump(predictions, output_file, ensure_ascii=False, indent=4)
src/reranking/bm25_sentences.py CHANGED
@@ -30,7 +30,7 @@ def retrieve_top_k_sentences(query, document, urls, top_k):
30
  if __name__ == "__main__":
31
 
32
  parser = argparse.ArgumentParser(
33
- description="Get top 100 sentences for sentences in the knowledge store"
34
  )
35
  parser.add_argument(
36
  "-k",
 
30
  if __name__ == "__main__":
31
 
32
  parser = argparse.ArgumentParser(
33
+ description="Get top 100 sentences with BM25 in the knowledge store."
34
  )
35
  parser.add_argument(
36
  "-k",
src/reranking/rerank_questions.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import torch
4
+ import tqdm
5
+ from transformers import BertTokenizer, BertForSequenceClassification
6
+ from models.DualEncoderModule import DualEncoderModule
7
+
8
+
9
+ def triple_to_string(x):
10
+ return " </s> ".join([item.strip() for item in x])
11
+
12
+
13
+ if __name__ == "__main__":
14
+ parser = argparse.ArgumentParser(
15
+ description="Rerank the QA paris and keep top 3 QA paris as evidence using a pre-trained BERT model."
16
+ )
17
+ parser.add_argument(
18
+ "-i",
19
+ "--top_k_qa_file",
20
+ default="data/dev_top_k_qa.json",
21
+ help="Json file with claim and top k generated question-answer pairs.",
22
+ )
23
+ parser.add_argument(
24
+ "-o",
25
+ "--output_file",
26
+ default="data/dev_top_3_rerank_qa.json",
27
+ help="Json file with the top3 reranked questions.",
28
+ )
29
+ parser.add_argument(
30
+ "-ckpt",
31
+ "--best_checkpoint",
32
+ type=str,
33
+ default="pretrained_models/bert_dual_encoder.ckpt",
34
+ )
35
+ parser.add_argument(
36
+ "--top_n",
37
+ type=int,
38
+ default=3,
39
+ help="top_n question answer pairs as evidence to keep.",
40
+ )
41
+ args = parser.parse_args()
42
+
43
+ with open(args.top_k_qa_file) as f:
44
+ examples = json.load(f)
45
+
46
+ bert_model_name = "bert-base-uncased"
47
+
48
+ tokenizer = BertTokenizer.from_pretrained(bert_model_name)
49
+ bert_model = BertForSequenceClassification.from_pretrained(
50
+ bert_model_name, num_labels=2, problem_type="single_label_classification"
51
+ )
52
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
53
+ trained_model = DualEncoderModule.load_from_checkpoint(
54
+ args.best_checkpoint, tokenizer=tokenizer, model=bert_model
55
+ ).to(device)
56
+
57
+ with open(args.output_file, "w", encoding="utf-8") as output_file:
58
+ for example in tqdm.tqdm(examples):
59
+ strs_to_score = []
60
+ values = []
61
+
62
+ bm25_qau = example["bm25_qau"] if "bm25_qau" in example else []
63
+ claim = example["claim"]
64
+
65
+ for question, answer, url in bm25_qau:
66
+ str_to_score = triple_to_string([claim, question, answer])
67
+
68
+ strs_to_score.append(str_to_score)
69
+ values.append([question, answer, url])
70
+
71
+ if len(bm25_qau) > 0:
72
+ encoded_dict = tokenizer(
73
+ strs_to_score,
74
+ max_length=512,
75
+ padding="longest",
76
+ truncation=True,
77
+ return_tensors="pt",
78
+ ).to(device)
79
+
80
+ input_ids = encoded_dict["input_ids"]
81
+ attention_masks = encoded_dict["attention_mask"]
82
+
83
+ scores = torch.softmax(
84
+ trained_model(input_ids, attention_mask=attention_masks).logits,
85
+ axis=-1,
86
+ )[:, 1]
87
+
88
+ top_n = torch.argsort(scores, descending=True)[: args.top_n]
89
+ evidence = [
90
+ {
91
+ "question": values[i][0],
92
+ "answer": values[i][1],
93
+ "url": values[i][2],
94
+ }
95
+ for i in top_n
96
+ ]
97
+ else:
98
+ evidence = []
99
+
100
+ json_data = {
101
+ "claim_id": example["claim_id"],
102
+ "claim": claim,
103
+ "evidence": evidence,
104
+ }
105
+ output_file.write(json.dumps(json_data, ensure_ascii=False) + "\n")
106
+ output_file.flush()
src/retrieval/scraper_for_knowledge_store.py CHANGED
@@ -46,7 +46,7 @@ def scrape_text_from_url(url, temp_name):
46
 
47
  if __name__ == "__main__":
48
 
49
- parser = argparse.ArgumentParser(description="Scraping text from URL")
50
  parser.add_argument(
51
  "-i",
52
  "--tsv_input_file",
 
46
 
47
  if __name__ == "__main__":
48
 
49
+ parser = argparse.ArgumentParser(description="Scraping text from URLs.")
50
  parser.add_argument(
51
  "-i",
52
  "--tsv_input_file",