File size: 1,964 Bytes
fc1945b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import streamlit as st
import pandas as pd
from safetensors import safe_open
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import pickle

st.title('Search offers')


# Streamlit re-executes this script on each user interaction; without caching,
# both models and the corpus files would be loaded again every time.
@st.cache_resource
def _load_models():
    """Load the bi-encoder (retrieval) and cross-encoder (re-ranking) models once."""
    return (
        SentenceTransformer('multi-qa-MiniLM-L6-cos-v1'),
        CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2'),
    )


@st.cache_resource
def _load_corpus():
    """Read the stored tensors and the pickled passages from disk once.

    Returns:
        A ``(tensors, passages)`` tuple: every tensor found in
        ``embeddings.safetensors`` keyed by name, and the object unpickled
        from ``corpus.pickle``.
    """
    loaded = {}
    with safe_open("embeddings.safetensors", framework="pt") as f:
        for key in f.keys():
            loaded[key] = f.get_tensor(key)

    # NOTE: pickle can execute arbitrary code while loading; corpus.pickle
    # must only ever come from a trusted source.
    with open('corpus.pickle', 'rb') as f:
        corpus = pickle.load(f)
    return loaded, corpus


bi_encoder, cross_encoder = _load_models()
tensors, passages = _load_corpus()
corpus_embeddings = tensors['embedding']


def search(query, top_k):
    """Return the offers that best match *query*, re-ranked by the cross-encoder.

    The query is embedded with the module-level ``bi_encoder`` and compared
    against ``corpus_embeddings`` with ``util.semantic_search`` to obtain up
    to ``top_k`` candidates. Each (query, passage) pair is then scored by
    ``cross_encoder`` and the candidates are sorted by that score, highest
    first.

    Args:
        query: Free-text search string.
        top_k: Maximum number of candidates to retrieve and rows to return.

    Returns:
        A ``pandas.DataFrame`` with two columns: ``score`` (the cross-encoder
        score formatted to three decimals, as a string) and ``offers`` (the
        passage with newlines replaced by spaces, cut at the first ``{`` and
        stripped). Duplicate offers are dropped after retrieval, so fewer
        than ``top_k`` rows may come back. A blank query, a ``top_k`` below 1
        or an empty hit list yields an empty frame with the same columns.
    """
    empty = pd.DataFrame({'score': [], 'offers': []})
    # Nothing meaningful to search for: skip the model calls entirely.
    if not query or not query.strip() or top_k < 1:
        return empty

    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(
        query_embedding, corpus_embeddings, top_k=int(top_k)
    )[0]
    if not hits:
        return empty

    # Re-score every retrieved candidate against the query.
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score

    hits = sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)

    score_list, output_list = [], []
    # `hits` already holds at most top_k entries, so no further slicing is
    # needed (a fixed cap here would silently override larger top_k values).
    for hit in hits:
        score_list.append("{:.3f}".format(hit['cross-score']))
        text = passages[hit['corpus_id']].replace("\n", " ")
        # Keep only the text in front of the first '{'.
        output_list.append(text.partition('{')[0].strip())

    dataframe = pd.DataFrame({'score': score_list, 'offers': output_list})
    return dataframe.drop_duplicates(subset=['offers'], keep='first')
	
	
# Search form: collects the query text and the result limit, then renders the
# matching offers once the user presses "Submit".
with st.form("my_form"):
    # Adjacent string literals are used instead of a backslash continuation,
    # which would embed the next line's indentation into the visible label.
    query = st.text_input(
        "Enter the brand name, category or retailer name to search "
        "for relevant offers 👇",
        placeholder="Enter the text here",
    )
    num = st.number_input(
        'Maximum number of offers to display', min_value=1, max_value=10
    )

    submitted = st.form_submit_button("Submit")
    if submitted:
        if query.strip():
            df = search(query, int(num))
            st.dataframe(df, use_container_width=True)
        else:
            # A blank query has nothing to match against; ask for input.
            st.warning("Please enter a search term.")