Spaces:
Runtime error
Runtime error
File size: 1,612 Bytes
a44a20c 050fb16 a44a20c 050fb16 f697f38 050fb16 a44a20c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import pprint

import gradio
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, CrossEncoder, util

from clean_data import text_normalizer
# read data
# `df` is the corpus as a list of row dicts (one per CSV row). The rest of the
# file assumes list position i matches row i of `doc_embeddings`, since hits'
# 'corpus_id' is used to index both.
df = pd.read_csv('./assets/final_combined.csv').to_dict(orient='records')
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects; it is
# only tolerable because this file ships with the app. A plain float embedding
# matrix loads with allow_pickle=False — consider switching once confirmed.
doc_embeddings = np.load('./assets/final_combined_embed.npy', allow_pickle=True)

# Models used by semantic_search (bi-encoder) and re_ranker (cross-encoder).
# TODO(review): these checkpoint names are assumptions. The bi-encoder MUST be
# the same model that produced final_combined_embed.npy, otherwise query and
# document vectors are not comparable and retrieval results are meaningless.
BI_ENCODER_MODEL = 'multi-qa-MiniLM-L6-cos-v1'
CROSS_ENCODER_MODEL = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
bi_encoder = SentenceTransformer(BI_ENCODER_MODEL)
cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL)
def semantic_search(normalized_query, top_k=50):
    '''
    Retrieve candidate documents for a query with the bi-encoder.

    Args:
        normalized_query: query text, already passed through text_normalizer.
        top_k: maximum number of candidates to return (default 50).

    Returns:
        The hit dicts from sentence_transformers.util.semantic_search for this
        single query; each carries a 'corpus_id' indexing df / doc_embeddings.
    '''
    # Encode the parameter we were given (the original referenced an undefined
    # name `query` here, raising NameError on every call).
    query_embedding = bi_encoder.encode(normalized_query)
    # util.semantic_search works on a batch of queries and returns one hit list
    # per query; a single query was passed, so take the first list.
    hits = util.semantic_search(query_embedding, doc_embeddings, top_k=top_k)
    return hits[0]
def _doc_text(record):
    '''
    Build the plain text of one corpus record for the cross-encoder.

    The cross-encoder scores (query, text) pairs, so it needs document text,
    not an embedding vector. The CSV schema is not visible here, so every
    non-missing field is joined with spaces.
    TODO(review): if the corpus has a dedicated text column, use it instead.
    '''
    return ' '.join(str(value) for value in record.values() if pd.notna(value))


def re_ranker(normalized_query, hits):
    '''
    Re-rank bi-encoder hits by cross-encoder relevance.

    Args:
        normalized_query: the same normalized query used for retrieval.
        hits: hit dicts from semantic_search; each needs a 'corpus_id'.

    Returns:
        The hits sorted by 'cross-score', highest first ([] for no hits).
        Each hit dict is also updated in place with its 'cross-score'.
    '''
    if not hits:
        # Nothing to score; also avoids calling the model on an empty batch.
        return []
    cross_inp = [[normalized_query, _doc_text(df[hit['corpus_id']])] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score
    return sorted(hits, key=lambda x: x['cross-score'], reverse=True)
def print_results(hits, k_items):
    '''
    Format the top `k_items` hits as human-readable text.

    Args:
        hits: ranked hit dicts, each with a 'corpus_id' indexing df.
        k_items: maximum number of hits to include.

    Returns:
        The pretty-printed corpus records separated by blank lines, or ""
        when there is nothing to show.
    '''
    # pformat output has no trailing newline, so without an explicit separator
    # consecutive records would run together on the same line.
    return '\n\n'.join(
        pprint.pformat(df[hit['corpus_id']], indent=4) for hit in hits[:k_items]
    )
def predict(query):
    '''
    Gradio handler: run retrieve-then-re-rank search for `query`.

    Args:
        query: raw search text from the UI textbox.

    Returns:
        The top 10 re-ranked corpus records formatted as text, or "" for a
        blank query.
    '''
    # Don't run the models on empty/whitespace-only input.
    if not query or not query.strip():
        return ""
    normalized_query = text_normalizer(query)
    bi_hits = semantic_search(normalized_query)
    # re_ranker needs the query as well as the hits (the original omitted it,
    # which raised TypeError).
    reranked_hits = re_ranker(normalized_query, bi_hits)
    return print_results(reranked_hits, k_items=10)
# Web UI: a three-line text box in, plain-text search results out.
# The module is imported as `gradio` (no `gr` alias), so use that name.
app = gradio.Interface(
    fn=predict,
    inputs=gradio.Textbox(lines=3, placeholder="Enter Search Query..."),
    outputs="text",
    title="Semantic Search + Re-Ranker",
)
app.launch()