phamson02 commited on
Commit
1ee7f3c
1 Parent(s): 29631f9
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/child_passages.tsv filter=lfs diff=lfs merge=lfs -text
37
+ data/parent_passages.tsv filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+
3
+ import gradio as gr
4
+ import pandas as pd
5
+ from sentence_transformers import SentenceTransformer, util
6
+
7
+ bi_encoder = SentenceTransformer("phamson02/cotmae_biencoder2_170000_sbert")
8
+ # cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
9
+ corpus_embeddings = pd.read_pickle("data/passage_embeds.pkl")
10
+
11
+ with open("data/child_passages.tsv", "r") as f:
12
+ tsv_reader = csv.reader(f, delimiter="\t")
13
+ child_passage_ids = [row[0] for row in tsv_reader]
14
+
15
+ with open("data/parent_passages.tsv", "r") as f:
16
+ tsv_reader = csv.reader(f, delimiter="\t")
17
+ parent_passages = {row[0]: row[1] for row in tsv_reader}
18
+
19
+
20
+ def f7(seq):
21
+ seen = set()
22
+ seen_add = seen.add
23
+ return [x for x in seq if not (x in seen or seen_add(x))]
24
+
25
+
26
+ def search(query: str, top_k: int = 100, reranking: bool = False):
27
+ print("Top 5 Answer by the NSE:")
28
+ print()
29
+ ans: list[str] = []
30
+ ##### Sematic Search #####
31
+ # Encode the query using the bi-encoder and find potentially relevant passages
32
+ question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
33
+ hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
34
+ hits = hits[0] # Get the hits for the first query
35
+
36
+ ##### Re-Ranking #####
37
+ # Now, score all retrieved passages with the cross_encoder
38
+ if reranking:
39
+ cross_inp = [[query, corpus[hit["corpus_id"]]] for hit in hits]
40
+ cross_scores = cross_encoder.predict(cross_inp)
41
+
42
+ # Sort results by the cross-encoder scores
43
+ for idx in range(len(cross_scores)):
44
+ hits[idx]["cross-score"] = cross_scores[idx]
45
+
46
+ hits = sorted(hits, key=lambda x: x["cross-score"], reverse=True)
47
+
48
+ top_20_hits = hits[0:20]
49
+ hit_child_passage_ids = [child_passage_ids[hit["corpus_id"]] for hit in top_20_hits]
50
+ hit_parent_passage_ids = f7(
51
+ [
52
+ "_".join(hit_child_passage_id.split("_")[:-1])
53
+ for hit_child_passage_id in hit_child_passage_ids
54
+ ]
55
+ )
56
+
57
+ assert len(hit_parent_passage_ids) >= 5, "Not enough unique parent passages found"
58
+
59
+ for hit in hit_parent_passage_ids[:5]:
60
+ ans.append(parent_passages[hit])
61
+
62
+ return ans[0], ans[1], ans[2], ans[3], ans[4]
63
+
64
+
65
+ exp = [
66
+ "Who is steve jobs?",
67
+ "What is coldplay?",
68
+ "What is a turing test?",
69
+ "What is the most interesting thing about our universe?",
70
+ "What are the most beautiful places on earth?",
71
+ ]
72
+
73
+ desc = "This is a semantic search engine powered by SentenceTransformers (Nils_Reimers) with a retrieval and reranking system on Wikipedia corous. This will return the top 5 results. So Quest on with Transformers."
74
+
75
+ inp = gr.Textbox(lines=1, placeholder=None, label="search you query here")
76
+ out1 = gr.Textbox(type="text", label="Search result 1")
77
+ out2 = gr.Textbox(type="text", label="Search result 2")
78
+ out3 = gr.Textbox(type="text", label="Search result 3")
79
+ out4 = gr.Textbox(type="text", label="Search result 4")
80
+ out5 = gr.Textbox(type="text", label="Search result 5")
81
+
82
+ iface = gr.Interface(
83
+ fn=search,
84
+ inputs=inp,
85
+ outputs=[out1, out2, out3, out4, out5],
86
+ examples=exp,
87
+ article=desc,
88
+ title="Neural Search Engine",
89
+ )
90
+ iface.launch()
data/child_passages.tsv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ab5beca4e38074457dd397a7b56f30679db2a92716bd31fa11a4c303db3dea3
3
+ size 428522185
data/parent_passages.tsv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8e35cd1a742779cbfb2ff5b7809ff92663a90ee92bd72c2d8b1d66f375e6535
3
+ size 352871807
data/passage_embeds.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3f4f423e0c7ef2021f17b8ca0dd78b1a32c5c75d8eaca06238ac12e8661ea0e
3
+ size 1060860322
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ sentence-transformers
2
+ torch
3
+ pandas
4
+ gradio