"""Gradio app: bi-encoder semantic search followed by cross-encoder re-ranking."""

import pprint

import gradio as gr
import numpy as np
import pandas as pd
from sentence_transformers import CrossEncoder, SentenceTransformer, util

from clean_data import text_normalizer

# NOTE(review): the original file used `bi_encoder` / `cross_encoder` without
# ever creating them. The model names below are placeholders taken from the
# sentence-transformers "retrieve & re-rank" example; the bi-encoder MUST be the
# same model that produced ./assets/final_combined_embed.npy -- confirm.
BI_ENCODER_NAME = "multi-qa-MiniLM-L6-cos-v1"
CROSS_ENCODER_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"

# read data
# `df` is a list of row dicts (one per CSV record), indexed by corpus position.
df = pd.read_csv('./assets/final_combined.csv').to_dict(orient='records')
# allow_pickle is only required if the array was saved with object dtype;
# pickle loading executes arbitrary code, so only ever load trusted files here.
doc_embeddings = np.load('./assets/final_combined_embed.npy', allow_pickle=True)

bi_encoder = SentenceTransformer(BI_ENCODER_NAME)
cross_encoder = CrossEncoder(CROSS_ENCODER_NAME)


def _record_text(record):
    """Return one text passage for a CSV record, for cross-encoder scoring.

    NOTE(review): the column that holds the document text cannot be determined
    from this file, so all non-null fields are joined. Replace with the single
    column that was embedded once that is confirmed.
    """
    return " ".join(str(value) for value in record.values() if pd.notna(value))


def semantic_search(normalized_query, top_k=50):
    """Retrieve the closest documents to a query with the bi-encoder.

    Args:
        normalized_query: Query text, already passed through text_normalizer.
        top_k: Maximum number of hits to retrieve.

    Returns:
        The hit list for the single query: dicts that carry at least a
        'corpus_id' index into ``doc_embeddings`` / ``df``.
    """
    query_embedding = bi_encoder.encode(normalized_query)
    hits = util.semantic_search(query_embedding, doc_embeddings, top_k=top_k)
    # util.semantic_search returns one hit list per query; only one was sent.
    return hits[0]


def re_ranker(normalized_query, hits):
    """Re-rank semantic search hits with the cross-encoder.

    Each hit dict is updated in place with a 'cross-score' entry.

    Args:
        normalized_query: Query text, already passed through text_normalizer.
        hits: Hit dicts as returned by semantic_search.

    Returns:
        A new list of the same hit dicts, sorted by 'cross-score' descending.
    """
    if not hits:
        return []
    # The cross-encoder scores (query, passage) *text* pairs, so look up the
    # record text rather than passing the stored embedding vectors.
    cross_inp = [
        [normalized_query, _record_text(df[hit['corpus_id']])] for hit in hits
    ]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score
    return sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)


def print_results(hits, k_items):
    """Format the first ``k_items`` hits as pretty-printed records.

    Args:
        hits: Hit dicts carrying a 'corpus_id' index into ``df``.
        k_items: Number of leading hits to include.

    Returns:
        One string with each record pretty-printed, separated by a blank line.
    """
    return "\n\n".join(
        pprint.pformat(df[hit['corpus_id']], indent=4) for hit in hits[:k_items]
    )


def predict(query):
    """Run the full pipeline for one raw query and return display text.

    Normalizes the query, retrieves candidates with the bi-encoder, re-ranks
    them with the cross-encoder, and formats the top 10 records.
    """
    normalized_query = text_normalizer(query)
    bi_hits = semantic_search(normalized_query)
    reranked_hits = re_ranker(normalized_query, bi_hits)
    return print_results(reranked_hits, k_items=10)


app = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter Search Query..."),
    outputs="text",
    title="Semantic Search + Re-Ranker",
)

if __name__ == "__main__":
    # Only start the web server when run as a script, not on import.
    app.launch()