File size: 1,612 Bytes
a44a20c
 
050fb16
 
 
a44a20c
050fb16
f697f38
050fb16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a44a20c
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import pprint

import gradio
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, CrossEncoder, util

from clean_data import text_normalizer

# Load the search corpus and its precomputed embeddings once, at import time.
# `df` is a list of row dicts (one per CSV record); other code in this file
# indexes it by a hit's 'corpus_id'.
df = pd.read_csv('./assets/final_combined.csv').to_dict(orient='records')
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, so
# this .npy file must come from a trusted source. Presumably row i of the array
# is the embedding of df[i] -- confirm against the script that produced it.
doc_embeddings = np.load('./assets/final_combined_embed.npy', allow_pickle=True)
# NOTE(review): `bi_encoder` and `cross_encoder` are referenced by the search
# functions in this file but are not defined anywhere visible here. They need
# to be created (SentenceTransformer / CrossEncoder are already imported) with
# the same bi-encoder model that generated the embeddings file above.

def semantic_search(normalized_query):
    '''
    Retrieve candidate documents for a search query with the bi-encoder.

    The query is encoded with the module-level ``bi_encoder`` and compared
    against the module-level ``doc_embeddings`` via ``util.semantic_search``.

    Args:
        normalized_query: query text to encode.

    Returns:
        The hit list for this single query (at most 50 entries, per
        ``top_k=50``). Other code in this file reads ``hit['corpus_id']``
        from each entry.
    '''
    # Encode the function's own argument (there is no `query` name in scope).
    query_embedding = bi_encoder.encode(normalized_query)
    hits = util.semantic_search(query_embedding, doc_embeddings, top_k=50)
    # Only one query was submitted, so take the first (only) result list.
    return hits[0]
    
def re_ranker(normalized_query, hits):
    '''
    Re-rank semantic search hits using cross-encoder scores.

    Args:
        normalized_query: query text paired with every candidate.
        hits: list of hit dicts, each carrying a 'corpus_id' key. Every dict
            gains a 'cross-score' key in place.

    Returns:
        A new list holding the same hit dicts, sorted by 'cross-score' with
        the highest score first. An empty ``hits`` yields an empty list.
    '''
    if not hits:
        # Nothing to score; skip the model call entirely.
        return []

    # NOTE(review): CrossEncoder.predict is documented to take pairs of texts,
    # but the second element here is a row of `doc_embeddings`, not the
    # document text. The text probably lives in the `df` record for the same
    # corpus_id -- confirm which column and switch to it.
    cross_inp = [
        [normalized_query, doc_embeddings[hit['corpus_id']]] for hit in hits
    ]
    cross_scores = cross_encoder.predict(cross_inp)

    # Scores come back in input order, so pair them with the hits directly.
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score
    return sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)


def print_results(hits, k_items):
    '''
    Render the first ``k_items`` hits as a single string.

    Each hit's 'corpus_id' indexes the module-level ``df``; the record found
    there is formatted with ``pprint.pformat`` (indent of 4). The formatted
    records are concatenated back to back, with no separator between them.
    '''
    top_hits = hits[:k_items]
    rendered = (
        pprint.pformat(df[entry['corpus_id']], indent=4) for entry in top_hits
    )
    return "".join(rendered)
        
def predict(query):
    '''
    Run the full search pipeline for one raw query string.

    The query is passed through ``text_normalizer``, candidates are fetched
    with ``semantic_search``, re-ordered by ``re_ranker``, and the top 10 are
    formatted by ``print_results``.

    Args:
        query: raw query text (the Gradio textbox value).

    Returns:
        The string produced by ``print_results`` for the 10 best hits.
    '''
    normalized_query = text_normalizer(query)

    bi_hits = semantic_search(normalized_query)
    # re_ranker needs the query as well as the hits to build its input pairs.
    reranked_hits = re_ranker(normalized_query, bi_hits)

    return print_results(reranked_hits, k_items=10)

# Gradio UI: a multi-line textbox feeds `predict`, whose returned string is
# shown as plain text. The package is imported as `gradio` (no `gr` alias), so
# it is referenced by its full name here.
app = gradio.Interface(
    fn=predict,
    inputs=gradio.Textbox(lines=3, placeholder="Enter Search Query..."),
    outputs="text",
    title="Semantic Search + Re-Ranker",
)

# Start the web server as soon as the script runs.
app.launch()