# Spaces: Runtime error
# (page-status text captured along with the source; not part of the program)
import pprint

import gradio
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, CrossEncoder, util

from clean_data import text_normalizer
# Read the corpus as a list of row dicts; rows are later looked up by
# position via each hit's 'corpus_id'.
df = pd.read_csv('./assets/final_combined.csv').to_dict(orient='records')

# Load the precomputed document embeddings.
# Fix: the keyword is `allow_pickle`; the misspelled `allow_pickel` made
# np.load raise TypeError before the app could start.
# SECURITY NOTE: allow_pickle=True unpickles objects stored in the file, which
# can execute arbitrary code; only load embedding files from a trusted source.
doc_embeddings = np.load('./assets/final_combined_embed.npy', allow_pickle=True)

# TODO(review): `bi_encoder` and `cross_encoder` are used by the functions
# below but are never created in this file. They presumably should be built
# from the imported SentenceTransformer / CrossEncoder classes, using the same
# bi-encoder model that produced the saved embeddings — confirm the model names.
def semantic_search(normalized_query):
    """Retrieve candidate documents for a query using the bi-encoder.

    Args:
        normalized_query: The query text, already normalized by the caller.

    Returns:
        The hit list for this single query (at most 50 entries). Each hit is
        a dict that carries at least a 'corpus_id' key identifying the
        matching row in the corpus.
    """
    # Fix: the original encoded the undefined name `query`, which raised
    # NameError; the parameter is `normalized_query`.
    query_embedding = bi_encoder.encode(normalized_query)
    hits = util.semantic_search(query_embedding, doc_embeddings, top_k=50)
    # Only one query was submitted, so take the first (and only) result list.
    return hits[0]
def re_ranker(normalized_query, hits):
    """Re-rank semantic-search hits using cross-encoder scores.

    Args:
        normalized_query: The query text, already normalized by the caller.
        hits: List of hit dicts, each with a 'corpus_id' key. The dicts are
            mutated in place: a 'cross-score' entry is added to each.

    Returns:
        A new list containing the same hit dicts, sorted by 'cross-score'
        in descending order. An empty input yields an empty list.
    """
    # Nothing to score; avoid calling the model with an empty batch.
    if not hits:
        return []

    # Fixes: the original referenced the undefined name `query` and iterated
    # over the undefined name `hit` instead of the `hits` parameter.
    # NOTE(review): as in the original, each pair couples the query with the
    # stored embedding row. Cross-encoders normally score (query, text) pairs,
    # so this probably should pass the document text from `df` — confirm which
    # column holds it before changing.
    cross_inp = [
        [normalized_query, doc_embeddings[hit['corpus_id']]] for hit in hits
    ]
    cross_scores = cross_encoder.predict(cross_inp)

    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score

    return sorted(hits, key=lambda x: x['cross-score'], reverse=True)
def print_results(hits, k_items):
    """Format the first `k_items` hits as one pretty-printed string.

    Each hit's 'corpus_id' is used to look up its record in `df`; the
    pretty-printed records are concatenated back to back, with no separator.
    """
    top_hits = hits[:k_items]
    formatted = (
        pprint.pformat(df[entry['corpus_id']], indent=4) for entry in top_hits
    )
    return "".join(formatted)
def predict(query):
    """Run the full search pipeline for a raw user query.

    The query is normalized, used to retrieve candidates with
    `semantic_search`, re-ranked with `re_ranker`, and the top 10 results
    are returned as formatted text.

    Args:
        query: The raw search text entered by the user.

    Returns:
        A string containing the pretty-printed top results.
    """
    normalized_query = text_normalizer(query)
    bi_hits = semantic_search(normalized_query)
    # Fix: re_ranker takes (normalized_query, hits); the original passed only
    # the hits, which raised TypeError for the missing argument.
    reranked_hits = re_ranker(normalized_query, bi_hits)
    return print_results(reranked_hits, k_items=10)
# Build and start the web UI: a multi-line text box feeds `predict`, and the
# returned string is shown as plain text.
# Fix: the file imports the package as `gradio` (no `gr` alias), so the
# original `gr.Interface` / `gr.Textbox` references raised NameError.
app = gradio.Interface(
    fn=predict,
    inputs=gradio.Textbox(lines=3, placeholder="Enter Search Query..."),
    outputs="text",
    title="Semantic Search + Re-Ranker",
)
app.launch()