import streamlit as st
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
import plotly.express as px
import plotly.io as pio
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE

pio.templates.default = "plotly"
st.set_page_config(layout="wide")
st.header("Explore the Russian Dolls :nesting_dolls: - _ :green[Nomic Embed 1.5]_", divider='violet')
st.write("Matryoshka Representation Learning")
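
# This Space demonstrates Matryoshka Representation Learning with Nomic Embed
# v1.5: the query and product embeddings are truncated to a user-chosen number
# of leading dimensions, re-normalized, and compared with cosine similarity.
# The top matches are shown in a table and all embeddings are projected to 2-D
# with t-SNE for inspection.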
def get_df():
    prodDf = pd.read_csv("./sample_products.csv")
    return prodDf
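
# Load Nomic Embed v1.5; trust_remote_code is required because the model ships
# its own modeling code on the Hugging Face Hub.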
def get_nomicModel():
    model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
    return model
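
# Embed the search query. Nomic Embed expects a task prefix ("search_query: "
# for queries, "search_document: " for corpus text), and the embedding is
# layer-normalized before any Matryoshka truncation, following the model's
# documented usage.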
def get_searchQueryEmbedding(query):
    embeddings = model.encode(["search_query: " + query], convert_to_tensor=True)
    embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
    return embeddings
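
# Keep only the first `matryoshka_dim` dimensions of both the query and the
# document embeddings, then L2-normalize so that the dot product below equals
# cosine similarity.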
def get_normEmbed(query_embedding, loaded_embed, matryoshka_dim):
    query_embedNorm = query_embedding[:, :matryoshka_dim]
    query_embedNorm = F.normalize(query_embedNorm, p=2, dim=1)
    loaded_embedNorm = loaded_embed[:, :matryoshka_dim]
    loaded_embedNorm = F.normalize(loaded_embedNorm, p=2, dim=1)
    return query_embedNorm, loaded_embedNorm
# Automatically wrap the hover text
def insert_line_breaks(text, interval=30):
    words = text.split(' ')
    wrapped_text = ''
    line_length = 0
    for word in words:
        wrapped_text += word + ' '
        line_length += len(word) + 1
        if line_length >= interval:
            wrapped_text += '<br>'
            line_length = 0
    return wrapped_text.strip()
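
# Module-level setup: load the model, the product descriptions, and the
# pre-computed document embeddings. The .npy file is assumed to hold the
# full-size (un-truncated) "search_document" embeddings for the same rows as
# sample_products.csv.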
model = get_nomicModel()
bigDollEmbedding = get_df()["Description"]
docEmbedding = torch.Tensor(np.load("./prodBigDollEmbeddings.npy"))
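
# Simple query form: results are computed only on the rerun triggered by the
# Submit button.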
with st.form("my_form"):
    query_input = st.text_input("query your product")
    sample_products = ["a", "b", "c"]  # currently unused
    submitted = st.form_submit_button("Submit")
if submitted:
    queryEmbedding = get_searchQueryEmbedding(query_input)
    Matry_dim = st.slider('Matryoshka Dimension', 64, 768, 64)
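
    # Truncate both sides to the chosen Matryoshka dimension and rank the
    # products by cosine similarity (dot product of L2-normalized vectors).
    # Note: the slider is only rendered after a submit, so adjusting it
    # requires submitting the form again.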
    query_embedNorm, loaded_embedNorm = get_normEmbed(queryEmbedding, docEmbedding, Matry_dim)
    similarity_scores = torch.matmul(query_embedNorm, loaded_embedNorm.T)
    top_values, top_indices = torch.topk(similarity_scores, 10, dim=1)
    to_index = list(top_indices.numpy()[0])
    top_items_per_query = [bigDollEmbedding.tolist()[index] for index in to_index]
    df = pd.DataFrame({"Product": top_items_per_query, "Score": top_values[0]})
    df["Product"] = df["Product"].str.replace("search_document:", "")
    # st.dataframe(df)
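
    # Project the query embedding together with all document embeddings to 2-D
    # with t-SNE so the neighborhood of the query can be inspected visually.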
    allEmbedd = torch.concat([query_embedNorm, loaded_embedNorm])
    tsne = TSNE(n_components=2, random_state=0)
    projections = tsne.fit_transform(allEmbedd.numpy())
    listHover = bigDollEmbedding.tolist()
    listHover = [insert_line_breaks(hover_text, 30) for hover_text in listHover]
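
    # Scatter plot of the 2-D projection: the first point is the query, the
    # rest are the product documents; hover text shows the wrapped description.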
    fig = px.scatter(
        projections, x=0, y=1,
        hover_name=[query_input] + listHover,
        color=["search_query"] + (["search_document"] * len(listHover))
    )
    col1, col2 = st.columns([2, 2])
    col2.plotly_chart(fig, use_container_width=True)
    col1.dataframe(df)