Spaces:
Sleeping
Sleeping
# streamlit_app.py | |
import streamlit as st | |
import pandas as pd | |
import torch | |
from sentence_transformers import SentenceTransformer, util | |
import pickle | |
# Streamlit re-executes this whole script on every widget interaction, so the
# expensive loads below are wrapped in cached helpers: each file / model is
# read once per server process and the same object is reused on later reruns.


@st.cache_resource
def _load_pickle(path):
    """Unpickle and return the object stored at *path* (cached per path)."""
    # NOTE(security): pickle.load can execute arbitrary code contained in the
    # file. Only point this at embedding files produced by a trusted pipeline.
    with open(path, "rb") as fIn:
        return pickle.load(fIn)


@st.cache_resource
def _load_embedder(model_name):
    """Instantiate the SentenceTransformer called *model_name* (cached)."""
    return SentenceTransformer(model_name)


# First store: read from the 'products' / 'embeddings' keys.
stored_data_1 = _load_pickle(
    'clinical_inno_embeddings_masterid_paraphrase-multilingual-mpnet-base-v2.pkl'
)
stored_masterid_1 = stored_data_1['pro_master_id']
stored_products_1 = stored_data_1['products']
stored_embeddings_1 = stored_data_1['embeddings']

# Second store: read from the 'mean_products' / 'mean_embeddings' keys.
stored_data_2 = _load_pickle(
    'mean_clinical_inno_embeddings_masterid_paraphrase-multilingual-mpnet-base-v2.pkl'
)
stored_masterid_2 = stored_data_2['pro_master_id']
stored_products_2 = stored_data_2['mean_products']
stored_embeddings_2 = stored_data_2['mean_embeddings']

# Model used to embed user queries.
# NOTE(review): presumably the same model that produced the stored embeddings
# (the pickle file names carry the same model id) - confirm.
embedder = _load_embedder('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')
def get_similar_products(query, products, embeddings, top_k=10):
    """Return the products whose embeddings are most similar to *query*.

    Args:
        query: Text to search for; it is encoded with the module-level
            ``embedder``.
        products: Sequence of product labels, indexed in parallel with
            ``embeddings``.
        embeddings: Stored embeddings that the query is compared against.
        top_k: Maximum number of matches to return. Fewer are returned when
            fewer embeddings are available.

    Returns:
        A list of ``(product, cosine_score)`` tuples, highest score first.
    """
    query_embedding = embedder.encode(query, convert_to_tensor=True)
    # cos_sim returns a (queries x candidates) matrix; there is a single
    # query, so row 0 holds its score against every stored embedding.
    cos_scores = util.cos_sim(query_embedding, embeddings)[0]
    # torch.topk raises if k exceeds the number of candidates, so clamp it.
    k = min(top_k, cos_scores.shape[0])
    top_results = torch.topk(cos_scores, k=k)
    return [
        (products[idx], score)
        for score, idx in zip(top_results.values.tolist(), top_results.indices.tolist())
    ]
# --- Streamlit UI ---
st.title("Product Similarity Finder")

# Two-position slider that picks which embedding store the search runs against.
embedding_option = st.select_slider(
    'Select Search Approach',
    options=['All Products', 'Master Products'],
)

# Resolve the (products, embeddings) pair for the chosen approach; any value
# other than 'All Products' falls through to the second store.
if embedding_option == 'All Products':
    stored_products, stored_embeddings = stored_products_1, stored_embeddings_1
else:
    stored_products, stored_embeddings = stored_products_2, stored_embeddings_2

# Show how many entries the selected store holds.
st.write(len(stored_products))

# Free-text query typed by the user.
user_input = st.text_input("Enter a product name or description:")

# Only act when the button is pressed on this run of the script.
search_clicked = st.button('Search')
if search_clicked and not user_input:
    # Button pressed with an empty query: prompt instead of searching.
    st.write("Please enter a product name or description to search.")
elif search_clicked:
    # Rank the stored products against the query.
    results = get_similar_products(user_input, stored_products, stored_embeddings)
    # Tabulate the (product, score) pairs and render scores with 4 decimals.
    results_df = pd.DataFrame(results, columns=['Product', 'Score'])
    st.dataframe(results_df.style.format({'Score': '{:.4f}'}))