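"""Gradio demo: hybrid product search that blends BM25 keyword matching with
multilingual sentence-embedding similarity over a pre-embedded catalogue."""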
import pprint
import torch
import pickle
import numpy as np
import gradio as gr
import pandas as pd
import rank_bm25  # BM25 package used to build the pickled keyword-search index loaded below
from sentence_transformers import SentenceTransformer, util


# Read the product catalogue; keep only the fields shown in the search results.
# `df` is a list of record dicts whose positions line up with the rows of `doc_embeddings`.
df = pd.read_csv("./assets/final_combined_raw.csv")[['category', 'brand', 'product_name']].to_dict(orient='records')
doc_embeddings = np.load("./assets/multi_embed.npy", allow_pickle=True)


# Semantic search model
semantic_model = SentenceTransformer("Abdul-Ib/paraphrase-multilingual-MiniLM-L12-v2-2024", cache_folder="./assets")

# Full-text (keyword) search: load the pre-built BM25 index
with open('./assets/bm25_L.pkl', 'rb') as bm25result_file:
    keyword_search = pickle.load(bm25result_file)


def full_text_search(normalized_query):
    """Score every document against the query with BM25 (lexical match)."""
    tokenized_query = normalized_query.lower().split(" ")
    ft_scores = keyword_search.get_scores(tokenized_query)
    return ft_scores

def semantic_search(normalized_query):
    """Score every document against the query by cosine similarity of sentence embeddings."""
    query_embedding = semantic_model.encode(normalized_query.lower())
    rr_scores = util.cos_sim(query_embedding, doc_embeddings)[0]
    return rr_scores

def hybrid_search(ft_scores, rr_scores):
    """Blend BM25 and semantic scores (70/30) and return the top-10 hits."""
    # Squash the unbounded BM25 scores with arctan into [-0.5, 0.5), then clip the
    # negative half so only reasonably strong lexical matches contribute.
    ft_scores = 2 / np.pi * np.arctan(ft_scores) - 0.5
    ft_scores[ft_scores < 0] = 0
    hybrid_scores = 0.7 * ft_scores + 0.3 * rr_scores.numpy()
    return torch.topk(torch.tensor(hybrid_scores), k=10)

def print_results(hits):
    """Format the top hits as pretty-printed product records."""
    results = ""
    for score, idx in zip(hits.values, hits.indices):
        results += pprint.pformat(df[idx.item()], indent=4) + "\n"
    return results

def predict(query):
    # The query is used as-is; lowercasing happens inside each search function.
    normalized_query = query

    bm25_scores = full_text_search(normalized_query)
    sem_scores = semantic_search(normalized_query)
    hits = hybrid_search(bm25_scores, sem_scores)

    return print_results(hits)

app = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter Search Query..."),
    outputs="text",
    title="Hybrid Search (Lexical Search + Semantic Search)",
)

app.launch()