Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
from safetensors import safe_open | |
from sentence_transformers import SentenceTransformer, CrossEncoder, util | |
import pickle | |
# --- Page header -----------------------------------------------------------
# The title and the two introductory paragraphs are kept as named constants
# so the rendering code below stays a short, uniform sequence of calls.
APP_TITLE = 'Search offers in Fetch app'

# Short description of what the Fetch Rewards app is.
ABOUT_FETCH = """Fetch Rewards is a mobile app where you can earn free gift cards by scanning and uploading your shopping receipts.
You accumulate points for eligible receipts, which can be redeemed for various gift cards. It's a way to get rewards for your
everyday shopping."""

# Explains what can be typed into the search box and what comes back.
USAGE_HELP = """
If you type in a category (ex.diapers), this search engine will return
a list of offers relevant to this category. You can also search using
brand name (ex. Huggies) or a retailer name (ex.Target). This tool will
return relevant offers related to that category, brand or retailer along
with the similarity score representing how similar the result offer is to
your search query."""

st.title(APP_TITLE)
for _paragraph in (ABOUT_FETCH, USAGE_HELP):
    st.markdown(_paragraph)
# --- Models and corpus -----------------------------------------------------
# Streamlit re-executes this script from top to bottom on every widget
# interaction.  Wrapping the loaders in ``st.cache_resource`` keeps the models
# and the corpus in memory across reruns instead of rebuilding them each time.
@st.cache_resource
def _load_encoders():
    """Instantiate the bi-encoder (retrieval) and cross-encoder (re-ranking).

    Returns:
        A ``(bi_encoder, cross_encoder)`` tuple.
    """
    return (
        SentenceTransformer('multi-qa-MiniLM-L6-cos-v1'),
        CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2'),
    )


@st.cache_resource
def _load_corpus():
    """Read the precomputed tensors and the pickled passages from disk.

    Returns:
        A ``(tensors, passages)`` tuple, where ``tensors`` maps every key in
        ``embeddings.safetensors`` to its tensor and ``passages`` is whatever
        object ``corpus.pickle`` contains (indexed by ``corpus_id`` in
        ``search``).
    """
    with safe_open("embeddings.safetensors", framework="pt") as f:
        loaded = {key: f.get_tensor(key) for key in f.keys()}
    # NOTE(security): unpickling can execute arbitrary code. This is only safe
    # because corpus.pickle is an artifact shipped with the app; never point
    # this at a user-supplied file.
    with open('corpus.pickle', 'rb') as f:
        corpus = pickle.load(f)
    return loaded, corpus


bi_encoder, cross_encoder = _load_encoders()
tensors, passages = _load_corpus()
corpus_embeddings = tensors['embedding']
def search(query, top_k):
    """Return offers relevant to *query*, re-ranked by the cross-encoder.

    The query is embedded with ``bi_encoder`` and the ``top_k`` nearest corpus
    entries are retrieved with ``util.semantic_search``.  Each retrieved
    passage is then scored against the query with ``cross_encoder`` and the
    hits are sorted by that score, best first.

    Args:
        query: Free-text search string (brand, category or retailer).
        top_k: Number of candidates to retrieve from the corpus.

    Returns:
        A ``pandas.DataFrame`` with two columns:

        * ``score``  - cross-encoder score formatted as a string with three
          decimals.
        * ``offers`` - passage text with newlines replaced by spaces and
          everything from the first ``{`` onwards removed.

        At most 10 rows are produced, and rows whose ``offers`` text repeats
        an earlier row are dropped, so fewer than ``top_k`` rows may come back.
    """
    max_rows = 10  # hard cap on displayed results, independent of top_k

    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    # semantic_search returns one list of hits per query; a single query is
    # passed, so only the first list is relevant.
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
    if not hits:
        # Nothing retrieved: skip the cross-encoder and return an empty table
        # with the usual columns.
        return pd.DataFrame({'score': [], 'offers': []})

    # Score every (query, passage) pair with the cross-encoder.
    pairs = [[query, passages[hit['corpus_id']]] for hit in hits]
    for hit, cross_score in zip(hits, cross_encoder.predict(pairs)):
        hit['cross-score'] = cross_score

    # Sort results by the cross-encoder scores
    ranked = sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)[:max_rows]

    scores = [f"{hit['cross-score']:.3f}" for hit in ranked]
    # Flatten each passage onto one line and keep only the text before the
    # first '{' (presumably trailing metadata appended to the offer text -
    # TODO confirm against how corpus.pickle was built).
    offers = [
        passages[hit['corpus_id']].replace("\n", " ").partition('{')[0].strip()
        for hit in ranked
    ]

    results = pd.DataFrame({'score': scores, 'offers': offers})
    # The same offer text can appear more than once; keep its best-ranked row.
    return results.drop_duplicates(subset=['offers'], keep='first')
# --- Search form -----------------------------------------------------------
# Widgets inside ``st.form`` only send their values when the submit button is
# pressed, so the search runs once per submission rather than per keystroke.
with st.form("my_form"):
    # NOTE(review): the trailing "π" looks like a mis-encoded emoji from the
    # original source; confirm the intended character.
    query = st.text_input(
        "Enter the brand name, category or retailer name to search "
        "for relevant offers π",
        placeholder="Enter the text here",
    )
    num = st.number_input('Maximum number of offers to display', min_value=1, max_value=10)
    submitted = st.form_submit_button("Submit")
    if submitted:
        if query.strip():
            df = search(query, num)
            st.dataframe(df, use_container_width=True)
        else:
            # A blank query would still be embedded and matched against the
            # corpus, giving arbitrary results; ask for input instead.
            st.warning("Please enter a brand, category or retailer name to search.")