rishikesh committed on
Commit
fc1945b
•
1 Parent(s): 10b2985

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +58 -0
  2. corpus.pickle +3 -0
  3. embeddings.safetensors +3 -0
  4. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from safetensors import safe_open
4
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
5
+ import pickle
6
+
7
st.title('Search offers')

# Streamlit re-executes this script on every widget interaction, so the
# expensive loads below are wrapped in cached loaders instead of running
# inline. Newer Streamlit releases call the decorator ``cache_resource``;
# fall back to the older ``experimental_singleton`` name when it is missing.
_cache_resource = getattr(st, 'cache_resource', None) or st.experimental_singleton


@_cache_resource
def _load_models():
    """Return the (bi-encoder, cross-encoder) pair used by ``search``."""
    return (
        SentenceTransformer('multi-qa-MiniLM-L6-cos-v1'),
        CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2'),
    )


@_cache_resource
def _load_corpus():
    """Return (tensors, passages) read from the files shipped next to the app.

    ``tensors`` maps every key stored in ``embeddings.safetensors`` to its
    tensor; ``passages`` is whatever object ``corpus.pickle`` contains.
    """
    with safe_open("embeddings.safetensors", framework="pt") as f:
        loaded_tensors = {k: f.get_tensor(k) for k in f.keys()}

    # NOTE(security): pickle.load executes arbitrary code embedded in the file.
    # This is only acceptable because corpus.pickle is bundled with the app;
    # never point it at user-supplied data.
    with open('corpus.pickle', 'rb') as f:
        loaded_passages = pickle.load(f)

    return loaded_tensors, loaded_passages


bi_encoder, cross_encoder = _load_models()
tensors, passages = _load_corpus()
corpus_embeddings = tensors['embedding']
20
+
21
+
22
def search(query, top_k):
    """Return the offers that best match ``query`` as a DataFrame.

    Retrieval is two-stage: the bi-encoder embeds the query and
    ``util.semantic_search`` selects the ``top_k`` closest entries of
    ``corpus_embeddings``; the cross-encoder then scores every
    (query, passage) pair and the hits are re-ordered by that score.

    Args:
        query: Free-text search string.
        top_k: Number of candidates to retrieve before re-ranking.

    Returns:
        A pandas DataFrame with two columns, best match first:
        ``score`` (cross-encoder score formatted to three decimals, as a
        string) and ``offers`` (passage text up to the first ``{``, with
        newlines replaced by spaces). Duplicate offers are dropped, keeping
        the highest-ranked one. The frame is empty when the query is blank
        or nothing is retrieved.
    """
    # A blank query has nothing meaningful to match; skip the model calls.
    if not query or not query.strip():
        return pd.DataFrame({'score': [], 'offers': []})

    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    # semantic_search returns one hit list per query; only one query is sent.
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
    if not hits:
        # Avoid calling the cross-encoder with an empty batch.
        return pd.DataFrame({'score': [], 'offers': []})

    # Re-rank the retrieved candidates with the cross-encoder.
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score

    hits = sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)

    score_list, output_list = [], []
    for hit in hits[:10]:  # never show more than ten rows
        score_list.append("{:.3f}".format(hit['cross-score']))
        passage_text = passages[hit['corpus_id']].replace("\n", " ")
        # Keep only the text that precedes the first '{' in the passage.
        output_list.append(passage_text.split('{', 1)[0].strip())

    dataframe = pd.DataFrame({'score': score_list, 'offers': output_list})
    return dataframe.drop_duplicates(subset=['offers'], keep='first')
47
+
48
+
49
# Collect the query and the result count in a form so that the search is
# only triggered when the user presses the submit button.
with st.form("my_form"):
    # Implicit string concatenation keeps the label free of the stray
    # whitespace a backslash continuation inside the literal would embed.
    query = st.text_input(
        "Enter the brand name, category or retailer name to search "
        "for relevant offers 👇",
        placeholder="Enter the text here",
    )
    num = st.number_input(
        'Maximum number of offers to display', min_value=1, max_value=10
    )

    submitted = st.form_submit_button("Submit")
    # NOTE(review): results are rendered inside the form container here;
    # confirm this matches the intended layout.
    if submitted:
        df = search(query, num)
        st.dataframe(df, use_container_width=True)
corpus.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5347f3cfe8e90681a71e27f3a77bbb906e1c05eb9f5cbe9c1f28eeafecdd94ab
3
+ size 72475
embeddings.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17f262c76dd1865aac7f4c486d7c951e1f0adfb7bc5e044b13268e67ead5a494
3
+ size 1274968
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ sentence-transformers==2.2.2
2
+ safetensors==0.3.3
3
+ pandas==1.5.3
4
# NOTE(review): pickle is part of the Python standard library and app.py imports it directly; this pin is presumably not installable from PyPI — confirm and drop it.
+ pickle==4.0