# Abdul-Ib — "Update app.py" (commit f697f38, verified)
# raw | history | blame — 1.61 kB
import pprint

import gradio
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, CrossEncoder, util

from clean_data import text_normalizer
# Read data: the corpus records and their precomputed document embeddings.
# `df` is a list of row dicts (one per CSV row); the code below indexes both
# `df` and `doc_embeddings` with the same `corpus_id`.
# NOTE(review): assumes row i of the CSV corresponds to row i of the .npy
# embedding matrix — confirm against whatever produced these assets.
df = pd.read_csv('./assets/final_combined.csv').to_dict(orient='records')
# `allow_pickle` is the np.load keyword (an unknown keyword raises TypeError).
# Pickle loading is tolerable only because this is a bundled local asset,
# never user-supplied input.
doc_embeddings = np.load('./assets/final_combined_embed.npy', allow_pickle=True)
# NOTE(review): `bi_encoder` and `cross_encoder` are used below but are not
# created anywhere in the visible code; they presumably need something like
# SentenceTransformer(<model>) / CrossEncoder(<model>) — TODO confirm models.
def semantic_search(normalized_query, top_k=50):
    '''
    Retrieve candidate documents for a query using the bi-encoder.

    Args:
        normalized_query: the query text, already passed through text_normalizer.
        top_k: number of nearest documents to return (defaults to 50).

    Returns:
        The hit list for this single query, as produced by
        sentence_transformers.util.semantic_search: dicts that carry a
        'corpus_id' key (used downstream to index `df` / `doc_embeddings`).
    '''
    # Embed the normalized query that was passed in.
    query_embedding = bi_encoder.encode(normalized_query)
    hits = util.semantic_search(query_embedding, doc_embeddings, top_k=top_k)
    # util.semantic_search returns one hit list per query; only one was sent.
    return hits[0]
def re_ranker(normalized_query, hits):
    '''
    Re-rank semantic-search hits using the cross-encoder.

    Args:
        normalized_query: the same normalized query text given to semantic_search.
        hits: list of hit dicts, each with a 'corpus_id' key.

    Returns:
        A new list of the hits sorted by descending 'cross-score'. Each hit
        dict is also updated in place with its 'cross-score'. An empty input
        yields an empty list without calling the cross-encoder.
    '''
    if not hits:
        # Nothing to score; avoid calling predict() on an empty batch.
        return []
    # NOTE(review): a CrossEncoder scores (text, text) pairs, but the second
    # element here is a row of `doc_embeddings`, not document text. It likely
    # should be the document's text from `df[hit['corpus_id']]` — TODO confirm
    # which column holds that text.
    cross_inp = [[normalized_query, doc_embeddings[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score
    return sorted(hits, key=lambda x: x['cross-score'], reverse=True)
def print_results(hits, k_items):
    '''
    Format the first `k_items` hits as a human-readable string.

    Each hit's record is looked up in `df` by its 'corpus_id' and rendered
    with pprint.pformat(indent=4). Records are separated by a blank line so
    consecutive dicts do not run together. Returns "" when there are no hits.
    '''
    return "\n\n".join(
        pprint.pformat(df[hit['corpus_id']], indent=4) for hit in hits[:k_items]
    )
def predict(query):
    '''
    Gradio callback: run the full search pipeline for one query.

    Normalizes the query with text_normalizer, retrieves candidates with
    semantic_search, re-ranks them with re_ranker, and returns the top 10
    records formatted by print_results.
    '''
    normalized_query = text_normalizer(query)
    bi_hits = semantic_search(normalized_query)
    # re_ranker needs the query text as well as the hits to build its pairs.
    reranked_hits = re_ranker(normalized_query, bi_hits)
    return print_results(reranked_hits, k_items=10)
# Build and launch the Gradio UI: a 3-line query text box in, plain text out,
# wired to predict(). The module is imported without an alias, so it is
# referenced as `gradio` (there is no `gr` name in scope).
app = gradio.Interface(
    fn=predict,
    inputs=gradio.Textbox(lines=3, placeholder="Enter Search Query..."),
    outputs="text",
    title="Semantic Search + Re-Ranker",
)
app.launch()