theachyuttiwari commited on
Commit
8a678af
·
1 Parent(s): 53a7a2f

Upload create_faiss_index.py

Browse files
Files changed (1) hide show
  1. create_faiss_index.py +67 -0
create_faiss_index.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import faiss
5
+ import torch
6
+ from datasets import load_dataset
7
+ from transformers import AutoTokenizer, DPRContextEncoder
8
+
9
+ from common import articles_to_paragraphs, embed_passages
10
+
11
+
12
+ def create_faiss(args):
13
+ dims = 128
14
+ min_chars_per_passage = 200
15
+ device = ("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ ctx_tokenizer = AutoTokenizer.from_pretrained(args.ctx_encoder_name)
18
+ ctx_model = DPRContextEncoder.from_pretrained(args.ctx_encoder_name).to(device)
19
+ _ = ctx_model.eval()
20
+
21
+ kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
22
+ kilt_wikipedia_columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
23
+ 'wikidata_info', 'history']
24
+
25
+ kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
26
+ remove_columns=kilt_wikipedia_columns,
27
+ batch_size=512,
28
+ cache_file_name=f"../data/wiki_kilt_paragraphs_full.arrow",
29
+ desc="Expanding wiki articles into paragraphs")
30
+
31
+ # use paragraphs that are not simple fragments or very short sentences
32
+ # Wikipedia Faiss index needs to fit into a 16 Gb GPU
33
+ kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
34
+ lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)
35
+
36
+ if not os.path.isfile(args.index_file_name):
37
+ def embed_passages_for_retrieval(examples):
38
+ return embed_passages(ctx_model, ctx_tokenizer, examples, max_length=128)
39
+
40
+ paragraphs_embeddings = kilt_wikipedia_paragraphs.map(embed_passages_for_retrieval,
41
+ batched=True, batch_size=512,
42
+ cache_file_name="../data/kilt_embedded.arrow",
43
+ desc="Creating faiss index")
44
+
45
+ paragraphs_embeddings.add_faiss_index(column="embeddings", custom_index=faiss.IndexFlatIP(dims))
46
+ paragraphs_embeddings.save_faiss_index("embeddings", args.index_file_name)
47
+ else:
48
+ print(f"Faiss index already exists {args.index_file_name}")
49
+
50
+
51
+ if __name__ == "__main__":
52
+ parser = argparse.ArgumentParser(description="Creates Faiss Wikipedia index file")
53
+
54
+ parser.add_argument(
55
+ "--ctx_encoder_name",
56
+ default="vblagoje/dpr-ctx_encoder-single-lfqa-base",
57
+ help="Encoding model to use for passage encoding",
58
+ )
59
+
60
+ parser.add_argument(
61
+ "--index_file_name",
62
+ default="../data/kilt_dpr_wikipedia.faiss",
63
+ help="Faiss index file with passage embeddings",
64
+ )
65
+
66
+ main_args, _ = parser.parse_known_args()
67
+ create_faiss(main_args)