Spaces:
Build error
Build error
Achyut Tiwari
commited on
Add files via upload
Browse files- util/common.py +120 -0
- util/create_dpr_training_from_dataset.py +103 -0
- util/create_dpr_training_from_faiss.py +144 -0
- util/create_faiss_index.py +67 -0
- util/eval_generate.py +140 -0
- util/kilt_create_dpr_support_docs.py +109 -0
- util/query_smoke_test.py +81 -0
util/common.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
kilt_wikipedia_columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
|
6 |
+
'wikidata_info', 'history']
|
7 |
+
|
8 |
+
kilt_wikipedia_paragraph_columns = ['wikipedia_id', 'start_paragraph_id', 'start_character', 'end_paragraph_id',
|
9 |
+
'end_character', 'title', 'section', 'text']
|
10 |
+
|
11 |
+
|
12 |
+
def clean_question(text):
|
13 |
+
result = cleanup_references(text)
|
14 |
+
result = result.replace("\n", " ")
|
15 |
+
result = re.sub(r"\s\s+", " ", result)
|
16 |
+
result = result.replace("[deleted]", "")
|
17 |
+
return result.lower().strip()
|
18 |
+
|
19 |
+
|
20 |
+
def cleanup_references(text):
|
21 |
+
# URL reference where we need to remove both the link text and URL
|
22 |
+
# ...and this letter is used by most biographers as the cornerstone of Lee's personal
|
23 |
+
# views on slavery ([1](_URL_2_ & pg=PA173), [2](_URL_1_), [3](_URL_5_)).
|
24 |
+
# ...and this letter is used by most biographers as the cornerstone of Lee's personal views on slavery.
|
25 |
+
result = re.sub(r"[\(\s]*\[\d+\]\([^)]+\)[,)]*", "", text, 0, re.MULTILINE)
|
26 |
+
|
27 |
+
# URL reference where we need to preserve link text but remove URL
|
28 |
+
# At the outbreak of the Civil War, [Leyburn left his church](_URL_19_) and joined the South.
|
29 |
+
# At the outbreak of the Civil War, Leyburn left his church and joined the South.
|
30 |
+
result = re.sub(r"\[([^]]+)\]\([^)]+\)", "\\1", result, 0, re.MULTILINE)
|
31 |
+
|
32 |
+
# lastly remove just dangling _URL_[0-9]_ URL references
|
33 |
+
result = re.sub(r"_URL_\d_", "", result, 0, re.MULTILINE)
|
34 |
+
return result
|
35 |
+
|
36 |
+
|
37 |
+
def clean_answer(text):
|
38 |
+
result = cleanup_references(text)
|
39 |
+
result = result.replace("\n", " ")
|
40 |
+
result = re.sub(r"\s\s+", " ", result)
|
41 |
+
result = re.sub(r"BULLET::::-", "", result)
|
42 |
+
return trim(result.strip())
|
43 |
+
|
44 |
+
|
45 |
+
def trim(text, word_count: int = 100):
|
46 |
+
return " ".join(text.split(" ")[:word_count])
|
47 |
+
|
48 |
+
|
49 |
+
def articles_to_paragraphs(examples):
|
50 |
+
ids, titles, sections, texts, start_ps, end_ps, start_cs, end_cs = [], [], [], [], [], [], [], []
|
51 |
+
for bidx, example in enumerate(examples["text"]):
|
52 |
+
last_section = ""
|
53 |
+
for idx, p in enumerate(example["paragraph"]):
|
54 |
+
if "Section::::" in p:
|
55 |
+
last_section = p
|
56 |
+
ids.append(examples["wikipedia_id"][bidx])
|
57 |
+
titles.append(examples["wikipedia_title"][bidx])
|
58 |
+
sections.append(last_section)
|
59 |
+
texts.append(p)
|
60 |
+
start_ps.append(idx)
|
61 |
+
end_ps.append(idx)
|
62 |
+
start_cs.append(0)
|
63 |
+
end_cs.append(len(p))
|
64 |
+
|
65 |
+
return {"wikipedia_id": ids, "title": titles,
|
66 |
+
"section": sections, "text": texts,
|
67 |
+
"start_paragraph_id": start_ps, "end_paragraph_id": end_ps,
|
68 |
+
"start_character": start_cs,
|
69 |
+
"end_character": end_cs
|
70 |
+
}
|
71 |
+
|
72 |
+
|
73 |
+
def create_kilt_datapoint(eli5_example, columns, wiki_passages, min_length=20, topk=7):
|
74 |
+
res_list = [dict([(k, p[k]) for k in columns]) for p in wiki_passages]
|
75 |
+
res_list = [res for res in res_list if len(res["text"].split()) > min_length][:topk]
|
76 |
+
|
77 |
+
# make a KILT data point
|
78 |
+
# see https://github.com/facebookresearch/KILT#kilt-data-format
|
79 |
+
output = []
|
80 |
+
for a in eli5_example["answers"]["text"]:
|
81 |
+
output.append({"answer": a})
|
82 |
+
|
83 |
+
output.append({"provenance": [
|
84 |
+
# evidence set for the answer from the KILT ks
|
85 |
+
{
|
86 |
+
"wikipedia_id": r["wikipedia_id"], # *mandatory*
|
87 |
+
"title": r["title"],
|
88 |
+
"section": r["section"],
|
89 |
+
"start_paragraph_id": r["start_paragraph_id"],
|
90 |
+
"start_character": r["start_character"],
|
91 |
+
"end_paragraph_id": r["end_paragraph_id"],
|
92 |
+
"end_character": r["end_character"],
|
93 |
+
"text": r["text"],
|
94 |
+
"bleu_score": None, # wrt original evidence
|
95 |
+
"meta": None # dataset/task specific
|
96 |
+
} for r in res_list
|
97 |
+
]})
|
98 |
+
return {"id": eli5_example["q_id"],
|
99 |
+
"input": eli5_example["title"],
|
100 |
+
"output": output, # each element is an answer or provenance (can have multiple of each)
|
101 |
+
"meta": None # dataset/task specific
|
102 |
+
}
|
103 |
+
|
104 |
+
|
105 |
+
def embed_questions(question_model, question_tokenizer, questions, max_length=128, device="cuda:0"):
|
106 |
+
query = question_tokenizer(questions, max_length=max_length, padding="max_length", truncation=True,
|
107 |
+
return_tensors="pt")
|
108 |
+
with torch.no_grad():
|
109 |
+
q_reps = question_model(query["input_ids"].to(device),
|
110 |
+
query["attention_mask"].to(device)).pooler_output
|
111 |
+
return q_reps.cpu().numpy()
|
112 |
+
|
113 |
+
|
114 |
+
def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cuda:0"):
|
115 |
+
p = ctx_tokenizer(passages["text"], max_length=max_length, padding="max_length",
|
116 |
+
truncation=True, return_tensors="pt")
|
117 |
+
with torch.no_grad():
|
118 |
+
a_reps = ctx_model(p["input_ids"].to(device),
|
119 |
+
p["attention_mask"].to(device)).pooler_output
|
120 |
+
return {"embeddings": a_reps.cpu().numpy()}
|
util/create_dpr_training_from_dataset.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import random
|
3 |
+
import json
|
4 |
+
import re
|
5 |
+
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
from sentence_transformers.util import semantic_search, cos_sim
|
8 |
+
from tqdm.auto import tqdm
|
9 |
+
from datasets import load_dataset
|
10 |
+
|
11 |
+
from common import clean_answer, clean_question
|
12 |
+
|
13 |
+
|
14 |
+
def find_hard_negative_ctxs(dataset, dataset_embeddings, embedding_index: int,
|
15 |
+
exclude_answer_patterns, similarity_threshold=[0.5, 0.6], k=25, min_count=3):
|
16 |
+
hard_negative_ctxs = []
|
17 |
+
results = semantic_search(dataset_embeddings[embedding_index], dataset_embeddings, top_k=k,
|
18 |
+
score_function=cos_sim)
|
19 |
+
# list if dicts
|
20 |
+
# [{'corpus_id': 8, 'score': -0.019427383318543434},
|
21 |
+
# ...
|
22 |
+
# {'corpus_id': 10, 'score': -0.09040290117263794}]
|
23 |
+
# hard negative are most similar and negatives are most disimilar to embedding_index
|
24 |
+
hard_negative_results = results[0][1:k + 1]
|
25 |
+
assert len(hard_negative_results) > min_count * 2
|
26 |
+
for r in hard_negative_results:
|
27 |
+
example = dataset[r["corpus_id"]]
|
28 |
+
if similarity_threshold[0] < r["score"] <= similarity_threshold[1]:
|
29 |
+
for a in example["answers"]["text"]:
|
30 |
+
hard_negative_ctxs.append({"title": "", "text": clean_answer(a)})
|
31 |
+
if len(hard_negative_ctxs) > min_count:
|
32 |
+
break
|
33 |
+
return hard_negative_ctxs[:min_count]
|
34 |
+
|
35 |
+
|
36 |
+
def find_negative_ctxs(dataset, dataset_embeddings, embedding_index: int,
|
37 |
+
exclude_answer_patterns, similarity_threshold=0.1, k=7, min_count=3):
|
38 |
+
negative_ctxs = []
|
39 |
+
random_sample = random.sample(range(len(dataset_embeddings)), k * 20)
|
40 |
+
similarities = cos_sim(dataset_embeddings[embedding_index], dataset_embeddings[random_sample])[0].tolist()
|
41 |
+
for idx, score in enumerate(similarities):
|
42 |
+
if score < similarity_threshold:
|
43 |
+
example = dataset[random_sample[idx]]
|
44 |
+
for a in example["answers"]["text"]:
|
45 |
+
negative_ctxs.append({"title": "", "text": clean_answer(a)})
|
46 |
+
if len(negative_ctxs) > min_count:
|
47 |
+
break
|
48 |
+
return negative_ctxs[:min_count]
|
49 |
+
|
50 |
+
|
51 |
+
def generate_dpr_training_file(args):
|
52 |
+
embedder = SentenceTransformer(args.embedding_model)
|
53 |
+
|
54 |
+
eli5_train_set = load_dataset("vblagoje/lfqa", split="train")
|
55 |
+
eli5_validation_set = load_dataset("vblagoje/lfqa", split="validation")
|
56 |
+
eli5_test_set = load_dataset("vblagoje/lfqa", split="test")
|
57 |
+
|
58 |
+
train_set = embedder.encode([example["title"] for example in eli5_train_set], convert_to_tensor=True,
|
59 |
+
show_progress_bar=True)
|
60 |
+
validation_set = embedder.encode([example["title"] for example in eli5_validation_set], convert_to_tensor=True,
|
61 |
+
show_progress_bar=True)
|
62 |
+
|
63 |
+
test_set = embedder.encode([example["title"] for example in eli5_test_set], convert_to_tensor=True,
|
64 |
+
show_progress_bar=True)
|
65 |
+
exclude_answer_patterns = [re.compile("not sure what you"), re.compile("\n\n >")]
|
66 |
+
for dataset_name, dataset, dataset_embeddings in zip(["train", "validation", "test"],
|
67 |
+
[eli5_train_set, eli5_validation_set, eli5_test_set],
|
68 |
+
[train_set, validation_set, test_set]):
|
69 |
+
min_elements = 3
|
70 |
+
skip_count = 0
|
71 |
+
progress_bar = tqdm(range(len(dataset)), desc="Creating DPR formatted question/passage docs")
|
72 |
+
with open('eli5-dpr-' + dataset_name + '.jsonl', 'w') as fp:
|
73 |
+
for idx, example in enumerate(dataset):
|
74 |
+
negative_ctxs = find_negative_ctxs(dataset, dataset_embeddings, idx, exclude_answer_patterns)
|
75 |
+
hard_negative_ctxs = find_hard_negative_ctxs(dataset, dataset_embeddings, idx, exclude_answer_patterns)
|
76 |
+
positive_context = [{"text": clean_answer(a), "title": ""} for a in example["answers"]["text"] if
|
77 |
+
not any([p.search(a) for p in exclude_answer_patterns])]
|
78 |
+
if not positive_context:
|
79 |
+
positive_context = [{"text": clean_answer(a), "title": ""} for a in example["answers"]["text"]]
|
80 |
+
if len(positive_context) > 0 and len(negative_ctxs) > 0 and len(hard_negative_ctxs) >= min_elements:
|
81 |
+
json.dump({"id": example["q_id"],
|
82 |
+
"question": clean_question(example["title"]),
|
83 |
+
"positive_ctxs": positive_context[:min_elements],
|
84 |
+
"negative_ctxs": negative_ctxs[:min_elements],
|
85 |
+
"hard_negative_ctxs": hard_negative_ctxs[:min_elements]}, fp)
|
86 |
+
fp.write("\n")
|
87 |
+
else:
|
88 |
+
skip_count += 1
|
89 |
+
progress_bar.update(1)
|
90 |
+
|
91 |
+
print(f"Skipped {skip_count} questions")
|
92 |
+
|
93 |
+
|
94 |
+
if __name__ == "__main__":
|
95 |
+
parser = argparse.ArgumentParser(description="Creates DPR training file from LFQA dataset")
|
96 |
+
parser.add_argument(
|
97 |
+
"--embedding_model",
|
98 |
+
default="all-mpnet-base-v2",
|
99 |
+
help="Embedding model to use for question encoding and semantic search",
|
100 |
+
)
|
101 |
+
|
102 |
+
main_args, _ = parser.parse_known_args()
|
103 |
+
generate_dpr_training_file(main_args)
|
util/create_dpr_training_from_faiss.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from datasets import load_dataset
|
6 |
+
from tqdm.auto import tqdm
|
7 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
8 |
+
from transformers import DPRQuestionEncoder
|
9 |
+
|
10 |
+
from common import embed_questions, clean_question, articles_to_paragraphs, kilt_wikipedia_columns
|
11 |
+
from common import kilt_wikipedia_paragraph_columns as columns
|
12 |
+
|
13 |
+
|
14 |
+
def generate_dpr_training_file(args):
|
15 |
+
n_negatives = 7
|
16 |
+
min_chars_per_passage = 200
|
17 |
+
|
18 |
+
def query_index(question, topk=(n_negatives * args.n_positives) * 2):
|
19 |
+
question_embedding = embed_questions(question_model, question_tokenizer, [question])
|
20 |
+
scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)
|
21 |
+
|
22 |
+
retrieved_examples = []
|
23 |
+
r = list(zip(wiki_passages[k] for k in columns))
|
24 |
+
for i in range(topk):
|
25 |
+
retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})
|
26 |
+
|
27 |
+
return retrieved_examples
|
28 |
+
|
29 |
+
def find_positive_and_hard_negative_ctxs(dataset_index: int, n_positive=1, device="cuda:0"):
|
30 |
+
positive_context_list = []
|
31 |
+
hard_negative_context_list = []
|
32 |
+
example = dataset[dataset_index]
|
33 |
+
question = clean_question(example['title'])
|
34 |
+
passages = query_index(question)
|
35 |
+
passages = [dict([(k, p[k]) for k in columns]) for p in passages]
|
36 |
+
q_passage_pairs = [[question, f"{p['title']} {p['text']}" if args.use_title else p["text"]] for p in passages]
|
37 |
+
|
38 |
+
features = ce_tokenizer(q_passage_pairs, padding="max_length", max_length=256, truncation=True,
|
39 |
+
return_tensors="pt")
|
40 |
+
with torch.no_grad():
|
41 |
+
passage_scores = ce_model(features["input_ids"].to(device),
|
42 |
+
features["attention_mask"].to(device)).logits
|
43 |
+
|
44 |
+
for p_idx, p in enumerate(passages):
|
45 |
+
p["score"] = passage_scores[p_idx].item()
|
46 |
+
|
47 |
+
# order by scores
|
48 |
+
def score_passage(item):
|
49 |
+
return item["score"]
|
50 |
+
|
51 |
+
# pick the most relevant as the positive answer
|
52 |
+
best_passage_list = sorted(passages, key=score_passage, reverse=True)
|
53 |
+
for idx, item in enumerate(best_passage_list):
|
54 |
+
if idx < n_positive:
|
55 |
+
positive_context_list.append({"title": item["title"], "text": item["text"]})
|
56 |
+
else:
|
57 |
+
break
|
58 |
+
|
59 |
+
# least relevant as hard_negative
|
60 |
+
worst_passage_list = sorted(passages, key=score_passage, reverse=False)
|
61 |
+
for idx, hard_negative in enumerate(worst_passage_list):
|
62 |
+
if idx < n_negatives * n_positive:
|
63 |
+
hard_negative_context_list.append({"title": hard_negative["title"], "text": hard_negative["text"]})
|
64 |
+
else:
|
65 |
+
break
|
66 |
+
assert len(positive_context_list) * n_negatives == len(hard_negative_context_list)
|
67 |
+
return positive_context_list, hard_negative_context_list
|
68 |
+
|
69 |
+
device = ("cuda" if torch.cuda.is_available() else "cpu")
|
70 |
+
|
71 |
+
question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
|
72 |
+
question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
|
73 |
+
_ = question_model.eval()
|
74 |
+
|
75 |
+
ce_model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/ms-marco-MiniLM-L-4-v2').to(device)
|
76 |
+
ce_tokenizer = AutoTokenizer.from_pretrained('cross-encoder/ms-marco-MiniLM-L-4-v2')
|
77 |
+
_ = ce_model.eval()
|
78 |
+
|
79 |
+
kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
|
80 |
+
|
81 |
+
kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
|
82 |
+
remove_columns=kilt_wikipedia_columns,
|
83 |
+
batch_size=512,
|
84 |
+
cache_file_name=f"../data/wiki_kilt_paragraphs_full.arrow",
|
85 |
+
desc="Expanding wiki articles into paragraphs")
|
86 |
+
|
87 |
+
# use paragraphs that are not simple fragments or very short sentences
|
88 |
+
# Wikipedia Faiss index needs to fit into a 16 Gb GPU
|
89 |
+
kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
|
90 |
+
lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)
|
91 |
+
|
92 |
+
kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=0)
|
93 |
+
|
94 |
+
eli5_train_set = load_dataset("vblagoje/lfqa", split="train")
|
95 |
+
eli5_validation_set = load_dataset("vblagoje/lfqa", split="validation")
|
96 |
+
eli5_test_set = load_dataset("vblagoje/lfqa", split="test")
|
97 |
+
|
98 |
+
for dataset_name, dataset in zip(["train", "validation", "test"], [eli5_train_set,
|
99 |
+
eli5_validation_set,
|
100 |
+
eli5_test_set]):
|
101 |
+
|
102 |
+
progress_bar = tqdm(range(len(dataset)), desc=f"Creating DPR formatted {dataset_name} file")
|
103 |
+
with open('eli5-dpr-' + dataset_name + '.jsonl', 'w') as fp:
|
104 |
+
for idx, example in enumerate(dataset):
|
105 |
+
negative_start_idx = 0
|
106 |
+
positive_context, hard_negative_ctxs = find_positive_and_hard_negative_ctxs(idx, args.n_positives,
|
107 |
+
device)
|
108 |
+
for pc in positive_context:
|
109 |
+
hnc = hard_negative_ctxs[negative_start_idx:negative_start_idx + n_negatives]
|
110 |
+
json.dump({"id": example["q_id"],
|
111 |
+
"question": clean_question(example["title"]),
|
112 |
+
"positive_ctxs": [pc],
|
113 |
+
"hard_negative_ctxs": hnc}, fp)
|
114 |
+
fp.write("\n")
|
115 |
+
negative_start_idx += n_negatives
|
116 |
+
progress_bar.update(1)
|
117 |
+
|
118 |
+
|
119 |
+
if __name__ == "__main__":
|
120 |
+
parser = argparse.ArgumentParser(description="Creates DPR training file")
|
121 |
+
parser.add_argument(
|
122 |
+
"--use_title",
|
123 |
+
action="store_true",
|
124 |
+
help="If true, use title in addition to passage text for passage embedding",
|
125 |
+
)
|
126 |
+
parser.add_argument(
|
127 |
+
"--n_positives",
|
128 |
+
default=3,
|
129 |
+
help="Number of positive samples per question",
|
130 |
+
)
|
131 |
+
parser.add_argument(
|
132 |
+
"--question_encoder_name",
|
133 |
+
default="vblagoje/dpr-question_encoder-single-lfqa-base",
|
134 |
+
help="Question encoder to use",
|
135 |
+
)
|
136 |
+
|
137 |
+
parser.add_argument(
|
138 |
+
"--index_file_name",
|
139 |
+
default="../data/kilt_dpr_wikipedia_first.faiss",
|
140 |
+
help="Faiss index with passage embeddings",
|
141 |
+
)
|
142 |
+
|
143 |
+
main_args, _ = parser.parse_known_args()
|
144 |
+
generate_dpr_training_file(main_args)
|
util/create_faiss_index.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
import faiss
|
5 |
+
import torch
|
6 |
+
from datasets import load_dataset
|
7 |
+
from transformers import AutoTokenizer, DPRContextEncoder
|
8 |
+
|
9 |
+
from common import articles_to_paragraphs, embed_passages
|
10 |
+
|
11 |
+
|
12 |
+
def create_faiss(args):
|
13 |
+
dims = 128
|
14 |
+
min_chars_per_passage = 200
|
15 |
+
device = ("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
+
|
17 |
+
ctx_tokenizer = AutoTokenizer.from_pretrained(args.ctx_encoder_name)
|
18 |
+
ctx_model = DPRContextEncoder.from_pretrained(args.ctx_encoder_name).to(device)
|
19 |
+
_ = ctx_model.eval()
|
20 |
+
|
21 |
+
kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
|
22 |
+
kilt_wikipedia_columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
|
23 |
+
'wikidata_info', 'history']
|
24 |
+
|
25 |
+
kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
|
26 |
+
remove_columns=kilt_wikipedia_columns,
|
27 |
+
batch_size=512,
|
28 |
+
cache_file_name=f"../data/wiki_kilt_paragraphs_full.arrow",
|
29 |
+
desc="Expanding wiki articles into paragraphs")
|
30 |
+
|
31 |
+
# use paragraphs that are not simple fragments or very short sentences
|
32 |
+
# Wikipedia Faiss index needs to fit into a 16 Gb GPU
|
33 |
+
kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
|
34 |
+
lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)
|
35 |
+
|
36 |
+
if not os.path.isfile(args.index_file_name):
|
37 |
+
def embed_passages_for_retrieval(examples):
|
38 |
+
return embed_passages(ctx_model, ctx_tokenizer, examples, max_length=128)
|
39 |
+
|
40 |
+
paragraphs_embeddings = kilt_wikipedia_paragraphs.map(embed_passages_for_retrieval,
|
41 |
+
batched=True, batch_size=512,
|
42 |
+
cache_file_name="../data/kilt_embedded.arrow",
|
43 |
+
desc="Creating faiss index")
|
44 |
+
|
45 |
+
paragraphs_embeddings.add_faiss_index(column="embeddings", custom_index=faiss.IndexFlatIP(dims))
|
46 |
+
paragraphs_embeddings.save_faiss_index("embeddings", args.index_file_name)
|
47 |
+
else:
|
48 |
+
print(f"Faiss index already exists {args.index_file_name}")
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == "__main__":
|
52 |
+
parser = argparse.ArgumentParser(description="Creates Faiss Wikipedia index file")
|
53 |
+
|
54 |
+
parser.add_argument(
|
55 |
+
"--ctx_encoder_name",
|
56 |
+
default="vblagoje/dpr-ctx_encoder-single-lfqa-base",
|
57 |
+
help="Encoding model to use for passage encoding",
|
58 |
+
)
|
59 |
+
|
60 |
+
parser.add_argument(
|
61 |
+
"--index_file_name",
|
62 |
+
default="../data/kilt_dpr_wikipedia.faiss",
|
63 |
+
help="Faiss index file with passage embeddings",
|
64 |
+
)
|
65 |
+
|
66 |
+
main_args, _ = parser.parse_known_args()
|
67 |
+
create_faiss(main_args)
|
util/eval_generate.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from datasets import load_dataset
|
7 |
+
from tqdm.auto import tqdm
|
8 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DPRQuestionEncoder
|
9 |
+
|
10 |
+
from common import articles_to_paragraphs, kilt_wikipedia_columns
|
11 |
+
from common import kilt_wikipedia_paragraph_columns as columns
|
12 |
+
|
13 |
+
|
14 |
+
def eval_generate(args):
|
15 |
+
device = ("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
+
question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
|
17 |
+
question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
|
18 |
+
_ = question_model.eval()
|
19 |
+
|
20 |
+
eli5_tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_eli5')
|
21 |
+
eli5_model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_eli5').to(device)
|
22 |
+
_ = eli5_model.eval()
|
23 |
+
|
24 |
+
min_snippet_length = 20
|
25 |
+
topk = 21
|
26 |
+
min_chars_per_passage = 200
|
27 |
+
kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
|
28 |
+
kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
|
29 |
+
remove_columns=kilt_wikipedia_columns,
|
30 |
+
batch_size=256,
|
31 |
+
cache_file_name=f"./data/wiki_kilt_paragraphs_full.arrow",
|
32 |
+
desc="Expanding wiki articles into paragraphs")
|
33 |
+
|
34 |
+
# use paragraphs that are not simple fragments or very short sentences
|
35 |
+
kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
|
36 |
+
lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)
|
37 |
+
kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=0)
|
38 |
+
|
39 |
+
def embed_questions_for_retrieval(questions):
|
40 |
+
query = question_tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
|
41 |
+
with torch.no_grad():
|
42 |
+
q_reps = question_model(query["input_ids"].to(device),
|
43 |
+
query["attention_mask"].to(device)).pooler_output
|
44 |
+
return q_reps.cpu().numpy()
|
45 |
+
|
46 |
+
def query_index(question):
|
47 |
+
question_embedding = embed_questions_for_retrieval([question])
|
48 |
+
scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)
|
49 |
+
|
50 |
+
retrieved_examples = []
|
51 |
+
r = list(zip(wiki_passages[k] for k in columns))
|
52 |
+
for i in range(topk):
|
53 |
+
retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})
|
54 |
+
return retrieved_examples
|
55 |
+
|
56 |
+
def create_kilt_datapoint(q_id, query, answer, res_list):
|
57 |
+
# make a KILT data point
|
58 |
+
# see https://github.com/facebookresearch/KILT#kilt-data-format
|
59 |
+
|
60 |
+
provenance = [{
|
61 |
+
"wikipedia_id": r["wikipedia_id"], # *mandatory*
|
62 |
+
"title": r["title"],
|
63 |
+
"section": r["section"],
|
64 |
+
"start_paragraph_id": r["start_paragraph_id"],
|
65 |
+
"start_character": r["start_character"],
|
66 |
+
"end_paragraph_id": r["end_paragraph_id"],
|
67 |
+
"end_character": r["end_character"],
|
68 |
+
"text": r["text"],
|
69 |
+
"bleu_score": None, # wrt original evidence
|
70 |
+
"meta": None # dataset/task specific
|
71 |
+
} for r in res_list]
|
72 |
+
|
73 |
+
output = [{"answer": answer, "provenance": provenance}]
|
74 |
+
|
75 |
+
return {"id": q_id,
|
76 |
+
"input": query,
|
77 |
+
"output": output, # each element is an answer or provenance (can have multiple of each)
|
78 |
+
"meta": None # dataset/task specific
|
79 |
+
}
|
80 |
+
|
81 |
+
kilt_output = []
|
82 |
+
with open(args.kilt_input_file, "r") as f:
|
83 |
+
kilt_items = [json.loads(x) for x in f.read().strip().split("\n")]
|
84 |
+
progress_bar = tqdm(range(len(kilt_items)), desc="Creating KILT response document")
|
85 |
+
for idx, item in enumerate(kilt_items):
|
86 |
+
query = item["input"]
|
87 |
+
res_list = query_index(query)
|
88 |
+
|
89 |
+
res_list = [res for res in res_list if len(res["text"].split()) > min_snippet_length][:int(topk / 3)]
|
90 |
+
documents = [res["text"] for res in res_list]
|
91 |
+
conditioned_doc = "<P> " + " <P> ".join([d for d in documents])
|
92 |
+
|
93 |
+
query_and_docs = "question: {} context: {}".format(query, conditioned_doc)
|
94 |
+
|
95 |
+
model_input = eli5_tokenizer(query_and_docs, truncation=True, padding=True, return_tensors="pt")
|
96 |
+
generated_answers_encoded = eli5_model.generate(input_ids=model_input["input_ids"].to(device),
|
97 |
+
attention_mask=model_input["attention_mask"].to(device),
|
98 |
+
min_length=50,
|
99 |
+
max_length=250,
|
100 |
+
do_sample=False,
|
101 |
+
early_stopping=True,
|
102 |
+
num_beams=8,
|
103 |
+
temperature=1.0,
|
104 |
+
top_k=None,
|
105 |
+
top_p=None,
|
106 |
+
no_repeat_ngram_size=3,
|
107 |
+
num_return_sequences=1)
|
108 |
+
answer = eli5_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
|
109 |
+
clean_up_tokenization_spaces=True)
|
110 |
+
|
111 |
+
kilt_example = create_kilt_datapoint(item["id"], query, answer[0], res_list)
|
112 |
+
kilt_output.append(kilt_example)
|
113 |
+
progress_bar.update(1)
|
114 |
+
|
115 |
+
with open(args.kilt_output_file, "w") as fp:
|
116 |
+
for kilt_example in kilt_output:
|
117 |
+
json.dump(kilt_example, fp)
|
118 |
+
fp.write("\n")
|
119 |
+
|
120 |
+
|
121 |
+
if __name__ == "__main__":
|
122 |
+
parser = argparse.ArgumentParser()
|
123 |
+
parser.add_argument('--kilt_input_file', default="./eli5-dev-kilt.jsonl", type=str)
|
124 |
+
parser.add_argument('--kilt_output_file', default="./eli5-predicted_retrieval.jsonl", type=str)
|
125 |
+
parser.add_argument(
|
126 |
+
"--question_encoder_name",
|
127 |
+
default="vblagoje/dpr-question_encoder-single-lfqa-base",
|
128 |
+
help="Question encoder to use",
|
129 |
+
)
|
130 |
+
|
131 |
+
parser.add_argument(
|
132 |
+
"--index_file_name",
|
133 |
+
default="../data/kilt_dpr_wikipedia_first.faiss",
|
134 |
+
help="Faiss index with passage embeddings",
|
135 |
+
)
|
136 |
+
|
137 |
+
args = parser.parse_args()
|
138 |
+
|
139 |
+
assert os.path.isfile(args.kilt_input_file), f"Input file {args.kilt_input_file} couldn't be loaded"
|
140 |
+
eval_generate(args)
|
util/kilt_create_dpr_support_docs.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
import faiss
|
6 |
+
import torch
|
7 |
+
from datasets import load_dataset, Dataset
|
8 |
+
from tqdm.auto import tqdm
|
9 |
+
from transformers import AutoTokenizer, DPRQuestionEncoder, DPRContextEncoder
|
10 |
+
|
11 |
+
from common import articles_to_paragraphs, embed_questions, embed_passages, create_kilt_datapoint, \
|
12 |
+
kilt_wikipedia_columns
|
13 |
+
from common import kilt_wikipedia_paragraph_columns as columns
|
14 |
+
|
15 |
+
|
16 |
+
def generate_support_docs(args):
|
17 |
+
dims = 128
|
18 |
+
min_chars_per_passage = 200
|
19 |
+
device = ("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
+
lfqa = load_dataset("vblagoje/lfqa")
|
21 |
+
|
22 |
+
ctx_tokenizer = AutoTokenizer.from_pretrained(args.ctx_encoder_name)
|
23 |
+
ctx_model = DPRContextEncoder.from_pretrained(args.ctx_encoder_name).to(device)
|
24 |
+
_ = ctx_model.eval()
|
25 |
+
|
26 |
+
question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
|
27 |
+
question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
|
28 |
+
_ = question_model.eval()
|
29 |
+
|
30 |
+
kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
|
31 |
+
|
32 |
+
kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
|
33 |
+
remove_columns=kilt_wikipedia_columns,
|
34 |
+
batch_size=512,
|
35 |
+
cache_file_name=f"../data/wiki_kilt_paragraphs_full.arrow",
|
36 |
+
desc="Expanding wiki articles into paragraphs")
|
37 |
+
|
38 |
+
# use paragraphs that are not simple fragments or very short sentences
|
39 |
+
# Wikipedia Faiss index needs to fit into a 16 Gb GPU
|
40 |
+
kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
|
41 |
+
lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)
|
42 |
+
|
43 |
+
def query_index(question, topk=7):
|
44 |
+
topk = topk * 3 # grab 3x results and filter for word count
|
45 |
+
question_embedding = embed_questions(question_model, question_tokenizer, [question])
|
46 |
+
scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)
|
47 |
+
|
48 |
+
retrieved_examples = []
|
49 |
+
r = list(zip(wiki_passages[k] for k in columns))
|
50 |
+
for i in range(topk):
|
51 |
+
retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})
|
52 |
+
|
53 |
+
return retrieved_examples
|
54 |
+
|
55 |
+
def create_support_doc(dataset: Dataset, output_filename: str):
|
56 |
+
progress_bar = tqdm(range(len(dataset)), desc="Creating supporting docs")
|
57 |
+
|
58 |
+
with open(output_filename, "w") as fp:
|
59 |
+
for example in dataset:
|
60 |
+
wiki_passages = query_index(example["title"])
|
61 |
+
kilt_dp = create_kilt_datapoint(example, columns, wiki_passages)
|
62 |
+
json.dump(kilt_dp, fp)
|
63 |
+
fp.write("\n")
|
64 |
+
progress_bar.update(1)
|
65 |
+
|
66 |
+
if not os.path.isfile(args.index_file_name):
|
67 |
+
def embed_passages_for_retrieval(examples):
|
68 |
+
return embed_passages(ctx_model, ctx_tokenizer, examples, max_length=128)
|
69 |
+
|
70 |
+
paragraphs_embeddings = kilt_wikipedia_paragraphs.map(embed_passages_for_retrieval,
|
71 |
+
batched=True, batch_size=512,
|
72 |
+
cache_file_name=args.encoded_kilt_file_name,
|
73 |
+
desc="Creating faiss index")
|
74 |
+
|
75 |
+
paragraphs_embeddings.add_faiss_index(column="embeddings", custom_index=faiss.IndexFlatIP(dims))
|
76 |
+
paragraphs_embeddings.save_faiss_index("embeddings", args.index_file_name)
|
77 |
+
|
78 |
+
kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=0)
|
79 |
+
create_support_doc(lfqa["train"], "lfqa_dpr_train_precomputed_dense_docs.json")
|
80 |
+
create_support_doc(lfqa["validation"], "lfqa_dpr_validation_precomputed_dense_docs.json")
|
81 |
+
|
82 |
+
|
83 |
+
if __name__ == "__main__":
|
84 |
+
parser = argparse.ArgumentParser(description="Creates support docs for seq2seq model training")
|
85 |
+
parser.add_argument(
|
86 |
+
"--ctx_encoder_name",
|
87 |
+
default="vblagoje/dpr-ctx_encoder-single-lfqa-base",
|
88 |
+
help="Question encoder to use",
|
89 |
+
)
|
90 |
+
parser.add_argument(
|
91 |
+
"--question_encoder_name",
|
92 |
+
default="vblagoje/dpr-question_encoder-single-lfqa-base",
|
93 |
+
help="Question encoder to use",
|
94 |
+
)
|
95 |
+
|
96 |
+
parser.add_argument(
|
97 |
+
"--index_file_name",
|
98 |
+
default="../data/kilt_dpr_wikipedia_first.faiss",
|
99 |
+
help="Faiss index with passage embeddings",
|
100 |
+
)
|
101 |
+
|
102 |
+
parser.add_argument(
|
103 |
+
"--encoded_kilt_file_name",
|
104 |
+
default="../data/kilt_embedded.arrow",
|
105 |
+
help="Encoded KILT file name",
|
106 |
+
)
|
107 |
+
|
108 |
+
main_args, _ = parser.parse_known_args()
|
109 |
+
generate_support_docs(main_args)
|
util/query_smoke_test.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModel
|
3 |
+
|
4 |
+
from datasets import load_dataset
|
5 |
+
|
6 |
+
|
7 |
+
def main():
|
8 |
+
device = ("cuda" if torch.cuda.is_available() else "cpu")
|
9 |
+
tokenizer = AutoTokenizer.from_pretrained('vblagoje/retribert-base-uncased')
|
10 |
+
model = AutoModel.from_pretrained('vblagoje/retribert-base-uncased').to(device)
|
11 |
+
_ = model.eval()
|
12 |
+
|
13 |
+
index_file_name = "./data/kilt_wikipedia.faiss"
|
14 |
+
kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
|
15 |
+
columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
|
16 |
+
'wikidata_info', 'history']
|
17 |
+
|
18 |
+
min_snippet_length = 20
|
19 |
+
topk = 21
|
20 |
+
|
21 |
+
def articles_to_paragraphs(examples):
|
22 |
+
ids, titles, sections, texts, start_ps, end_ps, start_cs, end_cs = [], [], [], [], [], [], [], []
|
23 |
+
for bidx, example in enumerate(examples["text"]):
|
24 |
+
last_section = ""
|
25 |
+
for idx, p in enumerate(example["paragraph"]):
|
26 |
+
if "Section::::" in p:
|
27 |
+
last_section = p
|
28 |
+
ids.append(examples["wikipedia_id"][bidx])
|
29 |
+
titles.append(examples["wikipedia_title"][bidx])
|
30 |
+
sections.append(last_section)
|
31 |
+
texts.append(p)
|
32 |
+
start_ps.append(idx)
|
33 |
+
end_ps.append(idx)
|
34 |
+
start_cs.append(0)
|
35 |
+
end_cs.append(len(p))
|
36 |
+
|
37 |
+
return {"wikipedia_id": ids, "title": titles,
|
38 |
+
"section": sections, "text": texts,
|
39 |
+
"start_paragraph_id": start_ps, "end_paragraph_id": end_ps,
|
40 |
+
"start_character": start_cs,
|
41 |
+
"end_character": end_cs
|
42 |
+
}
|
43 |
+
|
44 |
+
kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
|
45 |
+
remove_columns=columns,
|
46 |
+
batch_size=256, cache_file_name=f"./wiki_kilt_paragraphs_full.arrow",
|
47 |
+
desc="Expanding wiki articles into paragraphs")
|
48 |
+
|
49 |
+
# use paragraphs that are not simple fragments or very short sentences
|
50 |
+
kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(lambda x: x["end_character"] > 250)
|
51 |
+
kilt_wikipedia_paragraphs.load_faiss_index("embeddings", index_file_name, device=0)
|
52 |
+
|
53 |
+
def embed_questions_for_retrieval(questions):
|
54 |
+
query = tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
|
55 |
+
with torch.no_grad():
|
56 |
+
q_reps = model.embed_questions(query["input_ids"].to(device),
|
57 |
+
query["attention_mask"].to(device)).cpu().type(torch.float)
|
58 |
+
return q_reps.numpy()
|
59 |
+
|
60 |
+
def query_index(question):
|
61 |
+
question_embedding = embed_questions_for_retrieval([question])
|
62 |
+
scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)
|
63 |
+
columns = ['wikipedia_id', 'title', 'text', 'section', 'start_paragraph_id', 'end_paragraph_id', 'start_character','end_character']
|
64 |
+
retrieved_examples = []
|
65 |
+
r = list(zip(wiki_passages[k] for k in columns))
|
66 |
+
for i in range(topk):
|
67 |
+
retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})
|
68 |
+
return retrieved_examples
|
69 |
+
|
70 |
+
questions = ["What causes the contrails (cirrus aviaticus) behind jets at high altitude? ",
|
71 |
+
"Why does water heated to a room temeperature feel colder than the air around it?"]
|
72 |
+
res_list = query_index(questions[0])
|
73 |
+
res_list = [res for res in res_list if len(res["text"].split()) > min_snippet_length][:int(topk / 3)]
|
74 |
+
for res in res_list:
|
75 |
+
print("\n")
|
76 |
+
print(res)
|
77 |
+
|
78 |
+
|
79 |
+
main()
|
80 |
+
|
81 |
+
|