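"""Streamlit demo: Matryoshka ("Russian doll") retrieval with nomic-embed-text-v1.5.

Embeds a search query, truncates query and document embeddings to a user-selected
Matryoshka dimension, ranks sample products by similarity, and visualises the
result with a 2-D t-SNE projection.
"""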
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.io as pio
import streamlit as st
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE

pio.templates.default = "plotly"
st.set_page_config(layout="wide")
st.header("Explore the Russian Dolls :nesting_dolls: - _:green[Nomic Embed 1.5]_", divider='violet')
st.write("Matryoshka representation learning")
@st.cache_data
def get_df():
    # Sample product catalogue; cached so the CSV is read only once
    prodDf = pd.read_csv("./sample_products.csv")
    return prodDf


@st.cache_resource
def get_nomicModel():
    # Load Nomic Embed v1.5; the resource cache keeps a single model instance
    model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
    return model
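# nomic-embed-text-v1.5 expects a task prefix on each input: queries are embedded
# with "search_query: ", while the pre-computed product vectors appear to use
# "search_document: " (that prefix is stripped again before display below).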
def get_searchQueryEmbedding(query):
    embeddings = model.encode(["search_query: " + query], convert_to_tensor=True)
    # Layer-norm the full embedding before any Matryoshka truncation
    embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
    return embeddings
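# Matryoshka step: keep only the first `matryoshka_dim` coordinates of each
# layer-normed embedding, then L2-normalise so dot products become cosine
# similarities.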
def get_normEmbed(query_embedding, loaded_embed, matryoshka_dim):
    query_embedNorm = query_embedding[:, :matryoshka_dim]
    query_embedNorm = F.normalize(query_embedNorm, p=2, dim=1)
    loaded_embedNorm = loaded_embed[:, :matryoshka_dim]
    loaded_embedNorm = F.normalize(loaded_embedNorm, p=2, dim=1)
    return query_embedNorm, loaded_embedNorm
def insert_line_breaks(text, interval=30):
    # Wrap hover text with <br> roughly every `interval` characters so Plotly
    # tooltips stay readable
    words = text.split(' ')
    wrapped_text = ''
    line_length = 0
    for word in words:
        wrapped_text += word + ' '
        line_length += len(word) + 1
        if line_length >= interval:
            wrapped_text += '<br>'
            line_length = 0
    return wrapped_text.strip()
# Load the model, the product descriptions, and the pre-computed product embeddings
model = get_nomicModel()
bigDollEmbedding = get_df()["Description"]  # product description text (not vectors)
docEmbedding = torch.Tensor(np.load("./prodBigDollEmbeddings.npy"))
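# Search form: embed the query, truncate query and document vectors to the
# chosen Matryoshka dimension, and rank the top-10 products by similarity.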
with st.form("my_form"):
    query_input = st.text_input("query your product")
    sample_products = ["a", "b", "c"]
    submitted = st.form_submit_button("Submit")
    if submitted:
        queryEmbedding = get_searchQueryEmbedding(query_input)
        # Note: the slider is only rendered after the first submit; adjust it
        # and resubmit to change the truncation dimension
        Matry_dim = st.slider('Matryoshka Dimension', 64, 768, 64)
        query_embedNorm, loaded_embedNorm = get_normEmbed(queryEmbedding, docEmbedding, Matry_dim)
        similarity_scores = torch.matmul(query_embedNorm, loaded_embedNorm.T)
        top_values, top_indices = torch.topk(similarity_scores, 10, dim=1)
        to_index = list(top_indices.numpy()[0])
        top_items_per_query = [bigDollEmbedding.tolist()[index] for index in to_index]
        df = pd.DataFrame({"Product": top_items_per_query, "Score": top_values[0].numpy()})
        df["Product"] = df["Product"].str.replace("search_document:", "")
        # st.dataframe(df)
        allEmbedd = torch.concat([query_embedNorm, loaded_embedNorm])
        tsne = TSNE(n_components=2, random_state=0)
        projections = tsne.fit_transform(allEmbedd.numpy())
        listHover = bigDollEmbedding.tolist()
        listHover = [insert_line_breaks(hover_text, 30) for hover_text in listHover]
        fig = px.scatter(
            projections, x=0, y=1,
            hover_name=[query_input] + listHover,
            color=["search_query"] + (["search_document"] * len(listHover))
        )
        col1, col2 = st.columns([2, 2])
        col2.plotly_chart(fig, use_container_width=True)
        col1.dataframe(df)