File size: 3,639 Bytes
1ee7f3c
cad5609
1ee7f3c
 
 
 
1976bba
cad5609
 
1ee7f3c
 
cad5609
 
 
1ee7f3c
 
 
 
ffa52a5
1ee7f3c
 
 
cad5609
1ee7f3c
 
 
 
 
 
 
 
cad5609
1976bba
 
1ee7f3c
 
 
 
 
 
 
 
 
cad5609
9d40154
cad5609
1ee7f3c
 
 
cad5609
 
 
1ee7f3c
cad5609
 
 
 
1ee7f3c
cad5609
9d40154
 
 
1ee7f3c
 
 
 
 
 
 
 
 
 
cad5609
1ee7f3c
 
 
 
 
f1bbe75
 
 
 
 
1ee7f3c
 
 
 
 
cad5609
 
1ee7f3c
 
 
 
 
 
 
 
cad5609
1ee7f3c
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import csv
from typing import Any

import gradio as gr
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from underthesea import word_tokenize
from retriever_trainer import PretrainedColBERT


# Models used by search(): a bi-encoder for first-stage retrieval and ColBERT
# for optional reranking.
bi_encoder = SentenceTransformer("phamson02/cotmae_biencoder2_170000_sbert")
colbert = PretrainedColBERT(
    pretrained_model_name="phamson02/colbert2.1_290000",
)
# Precomputed child-passage embeddings. Assumes row i lines up with row i of
# child_passages.tsv, since search() indexes both with the same "corpus_id".
# NOTE(review): read_pickle can execute arbitrary code; only load a trusted file.
corpus_embeddings = pd.read_pickle("data/passage_embeds.pkl")

# Open the TSVs with an explicit encoding: the default encoding is
# locale-dependent, so passage text could decode differently per platform.
# Assumes the files are UTF-8 (confirm). newline="" is what the csv module
# requires for correct line handling.
with open("data/child_passages.tsv", "r", encoding="utf-8", newline="") as f:
    tsv_reader = csv.reader(f, delimiter="\t")
    # Column 0 is the child passage id, column 1 its text. Each becomes a tuple.
    child_passage_ids, child_passages = zip(*[(row[0], row[1]) for row in tsv_reader])

with open("data/parent_passages.tsv", "r", encoding="utf-8", newline="") as f:
    tsv_reader = csv.reader(f, delimiter="\t")
    # Maps parent passage id (column 0) to parent passage text (column 1).
    parent_passages_map = {row[0]: row[1] for row in tsv_reader}


def f7(seq):
    """Return the items of *seq* with duplicates removed, keeping first-seen order.

    Items must be hashable. *seq* may be any iterable, including a generator.

    >>> f7(["b", "a", "b", "c", "a"])
    ['b', 'a', 'c']
    """
    # dicts preserve insertion order, so the keys are the unique items in the
    # order they first appeared. This avoids the side-effect-in-comprehension trick.
    return list(dict.fromkeys(seq))


def search(query: str, reranking: bool = False, top_k: int = 100):
    query = word_tokenize(query, format="text")

    print("Top 5 Answer by the NSE:")
    print()
    ans: list[str] = []
    ##### Sematic Search #####
    # Encode the query using the bi-encoder and find potentially relevant passages
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first query

    top_k_child_passages = [child_passages[hit["corpus_id"]] for hit in hits]
    top_k_child_passage_ids = [hit["corpus_id"] for hit in hits]

    ##### Re-Ranking #####
    # Now, score all retrieved passages with the cross_encoder
    if reranking:
        colbert_scores: list[dict[str, Any]] = colbert.rerank(
            query=query, documents=top_k_child_passages, top_k=100
        )

        # Reorder child passage ids based on the reranking
        top_k_child_passage_ids = [
            top_k_child_passage_ids[score["corpus_id"]] for score in colbert_scores
        ]

    top_20_hits = top_k_child_passage_ids[0:20]

    hit_child_passage_ids = [child_passage_ids[id] for id in top_20_hits]

    hit_parent_passage_ids = f7(
        [
            "_".join(hit_child_passage_id.split("_")[:-1])
            for hit_child_passage_id in hit_child_passage_ids
        ]
    )

    assert len(hit_parent_passage_ids) >= 5, "Not enough unique parent passages found"

    for hit in hit_parent_passage_ids[:5]:
        ans.append(parent_passages_map[hit])

    return ans[0], ans[1], ans[2], ans[3], ans[4]


# Example rows for the Gradio form: each is [query text, reranking checkbox value].
exp = [
    [question, False]
    for question in (
        "Who is steve jobs?",
        "What is coldplay?",
        "What is a turing test?",
        "What is the most interesting thing about our universe?",
        "What are the most beautiful places on earth?",
    )
]

# Text shown as the article section of the Gradio interface.
desc = "This is a semantic search engine powered by SentenceTransformers (Nils_Reimers) with a retrieval and reranking system on Wikipedia corpus. This will return the top 5 results. So Quest on with Transformers."

# Inputs, in the positional order search(query, reranking) expects.
inp = gr.Textbox(lines=1, placeholder=None, label="search your query here")
reranking_checkbox = gr.Checkbox(label="Enable reranking")

# One textbox per element of the 5-tuple returned by search(),
# labelled "Search result 1" through "Search result 5".
out1, out2, out3, out4, out5 = (
    gr.Textbox(type="text", label=f"Search result {i}") for i in range(1, 6)
)

# Wire search() into a Gradio interface (query + reranking toggle in, five
# result textboxes out) and start serving it.
iface = gr.Interface(
    title="Neural Search Engine",
    article=desc,
    fn=search,
    inputs=[inp, reranking_checkbox],
    outputs=[out1, out2, out3, out4, out5],
    examples=exp,
)
iface.launch()