import pprint
import torch
import pickle
import numpy as np
import gradio as gr
import pandas as pd
from rank_bm25 import *
from sentence_transformers import SentenceTransformer, util
# read data
df = pd.read_csv("./assets/final_combined_raw.csv")[['category', 'brand', 'product_name']].to_dict(orient='records')
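# precomputed embeddings for every product record (presumably built offline with the e5 model loaded below)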
doc_embeddings = np.load("./assets/e5_embed.npy", allow_pickle=True)
# Semantic Search model
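# (cache_folder keeps the downloaded checkpoint inside the Space's assets directory)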
semantic_model = SentenceTransformer("Abdul-Ib/multilingual-e5-small-2024", cache_folder="./assets")
# full-text (keyword) search: load the pickled BM25 index
with open('./assets/bm25_L.pkl', 'rb') as bm25result_file:
    keyword_search = pickle.load(bm25result_file)
def full_text_search(normalized_query):
    # score every document in the corpus against the query with BM25
    tokenized_query = normalized_query.lower().split(" ")
    ft_scores = keyword_search.get_scores(tokenized_query)
    return ft_scores
def semantic_search(normalized_query):
    # embed the query and score it against the precomputed document embeddings
    query_embedding = semantic_model.encode(normalized_query.lower())
    rr_scores = util.cos_sim(query_embedding, doc_embeddings)[0]
    return rr_scores
def hybrid_search(ft_scores, rr_scores):
    # squash unbounded BM25 scores into [-0.5, 0.5) with arctan, then zero out weak matches
    ft_scores = 2 / np.pi * np.arctan(ft_scores) - 0.5
    ft_scores[ft_scores < 0] = 0
    # blend lexical and semantic scores (0.7 / 0.3) and keep the top 10
    hybrid_scores = 0.7 * ft_scores + 0.3 * rr_scores.numpy()
    return torch.topk(torch.tensor(hybrid_scores), k=10)
def print_results(hits):
    # format the top-k (score, index) pairs returned by torch.topk as readable records
    results = ""
    for score, idx in zip(hits[0], hits[1]):
        results += pprint.pformat(df[idx.item()], indent=4) + "\n"
    return results
def predict(query):
    # the query is passed through as-is; lowercasing happens inside the search functions
    normalized_query = query
    bm25_scores = full_text_search(normalized_query)
    sem_scores = semantic_search(normalized_query)
    hits = hybrid_search(bm25_scores, sem_scores)
    return print_results(hits)
app = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter Search Query..."),
    outputs="text",
    title="Hybrid Search (Lexical Search + Semantic Search)",
)
app.launch()
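# ------------------------------------------------------------------
# The two asset files above ("./assets/bm25_L.pkl" and "./assets/e5_embed.npy")
# are assumed to be built offline. A minimal sketch of how such assets could be
# generated from the same CSV (the text field choice and preprocessing here are
# hypothetical; the actual build script is not part of this Space):
#
#   corpus = [f"{r['brand']} {r['product_name']}" for r in df]
#   tokenized_corpus = [doc.lower().split(" ") for doc in corpus]
#   bm25 = BM25L(tokenized_corpus)  # class guessed from the "_L" filename suffix
#   with open("./assets/bm25_L.pkl", "wb") as f:
#       pickle.dump(bm25, f)
#   np.save("./assets/e5_embed.npy",
#           semantic_model.encode([doc.lower() for doc in corpus]))
# ------------------------------------------------------------------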