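"""Gradio demo of a neural search engine: bi-encoder retrieval over a
precomputed passage corpus, with optional ColBERT reranking of the top hits."""
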
import csv
from typing import Any

import gradio as gr
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from underthesea import word_tokenize

from retriever_trainer import PretrainedColBERT
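
# First-stage bi-encoder retriever and second-stage ColBERT reranker.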
bi_encoder = SentenceTransformer("phamson02/cotmae_biencoder2_170000_sbert")
colbert = PretrainedColBERT(
    pretrained_model_name="phamson02/colbert2.1_290000",
)
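
# Precomputed passage embeddings used as the search corpus for the bi-encoder.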
corpus_embeddings = pd.read_pickle("data/passage_embeds.pkl")
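
# Load the child passages (ids run parallel to corpus_embeddings) and the
# id -> text lookup table for their parent passages.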
with open("data/child_passages.tsv", "r") as f:
tsv_reader = csv.reader(f, delimiter="\t")
child_passage_ids, child_passages = zip(*[(row[0], row[1]) for row in tsv_reader])
with open("data/parent_passages.tsv", "r") as f:
tsv_reader = csv.reader(f, delimiter="\t")
parent_passages_map = {row[0]: row[1] for row in tsv_reader}


def f7(seq):
    """Return seq with duplicates removed, preserving first-seen order."""
    seen = set()
    seen_add = seen.add
    return [x for x in seq if not (x in seen or seen_add(x))]


def search(query: str, reranking: bool = False, top_k: int = 100):
    # The models expect word-segmented text, so tokenize the query
    # with underthesea first.
    query = word_tokenize(query, format="text")
    print("Top 5 answers by the NSE:")
    print()

    ans: list[str] = []

    ##### Semantic Search #####
    # Encode the query using the bi-encoder and find potentially relevant passages
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first query
    top_k_child_passages = [child_passages[hit["corpus_id"]] for hit in hits]
    top_k_child_passage_ids = [hit["corpus_id"] for hit in hits]

    ##### Re-Ranking #####
    # Score all retrieved passages with ColBERT and reorder them.
    if reranking:
        colbert_scores: list[dict[str, Any]] = colbert.rerank(
            query=query, documents=top_k_child_passages, top_k=top_k
        )
        # Reorder child passage ids based on the reranking
        top_k_child_passage_ids = [
            top_k_child_passage_ids[score["corpus_id"]] for score in colbert_scores
        ]

    # Map the top 20 child passages back to their parents. Child passage ids
    # have the form "<parent_id>_<chunk_index>", so stripping the last "_"
    # segment yields the parent id; f7 deduplicates while keeping rank order.
    top_20_hits = top_k_child_passage_ids[0:20]
    hit_child_passage_ids = [child_passage_ids[idx] for idx in top_20_hits]
    hit_parent_passage_ids = f7(
        [
            "_".join(hit_child_passage_id.split("_")[:-1])
            for hit_child_passage_id in hit_child_passage_ids
        ]
    )

    assert len(hit_parent_passage_ids) >= 5, "Not enough unique parent passages found"
    for hit in hit_parent_passage_ids[:5]:
        ans.append(parent_passages_map[hit])

    return ans[0], ans[1], ans[2], ans[3], ans[4]


exp = [
    ["Who is Steve Jobs?", False],
    ["What is Coldplay?", False],
    ["What is a Turing test?", False],
    ["What is the most interesting thing about our universe?", False],
    ["What are the most beautiful places on earth?", False],
]
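
# Gradio UI: one query box, a reranking toggle, and five result textboxes.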
desc = "This is a semantic search engine powered by SentenceTransformers (Nils_Reimers) with a retrieval and reranking system on Wikipedia corous. This will return the top 5 results. So Quest on with Transformers."
inp = gr.Textbox(lines=1, placeholder=None, label="Search your query here")
reranking_checkbox = gr.Checkbox(label="Enable reranking")
out1 = gr.Textbox(type="text", label="Search result 1")
out2 = gr.Textbox(type="text", label="Search result 2")
out3 = gr.Textbox(type="text", label="Search result 3")
out4 = gr.Textbox(type="text", label="Search result 4")
out5 = gr.Textbox(type="text", label="Search result 5")

iface = gr.Interface(
    fn=search,
    inputs=[inp, reranking_checkbox],
    outputs=[out1, out2, out3, out4, out5],
    examples=exp,
    article=desc,
    title="Neural Search Engine",
)
iface.launch()