Spaces:
Running
Running
import functools
import os

import joblib
import pandas as pd
import plotly.express as px
import streamlit as st
import umap.umap_ as umap
# NOTE(review): cached_download is deprecated and has been removed from recent
# huggingface_hub releases (this import then raises ImportError); drop it once
# nothing in the file uses it.
from huggingface_hub import cached_download, hf_hub_download, hf_hub_url
from sentence_transformers import SentenceTransformer
@functools.lru_cache(maxsize=1)
def init_models():
    """Load the sentence encoder and the pre-fitted UMAP projection model.

    The result is memoised with ``functools.lru_cache``. Repeated calls, such
    as one per "Embed" click, reuse the already-loaded objects instead of
    re-instantiating the SentenceTransformer and re-reading the UMAP file. The
    cache lives as long as this module object does.

    Returns:
        tuple: ``(model, umap_model)``.
            ``model`` is a ``SentenceTransformer`` for
            ``sentence-transformers/all-MiniLM-L6-v2``.
            ``umap_model`` is the object unpickled from
            ``umap_embed_3d_all-MiniLM-L6-v2.sav`` in the Hub repo
            ``peter2000/umap_embed_3d_all-MiniLM-L6-v2``. It is expected to
            expose ``.transform(embeddings)``.
    """
    model_name = 'sentence-transformers/all-MiniLM-L6-v2'
    model = SentenceTransformer(model_name)

    repo_id = "peter2000/umap_embed_3d_all-MiniLM-L6-v2"
    filename = "umap_embed_3d_all-MiniLM-L6-v2.sav"
    # hf_hub_download replaces the deprecated/removed hf_hub_url +
    # cached_download pair. It fetches the file into the local Hub cache and
    # returns the local path.
    umap_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # SECURITY NOTE: joblib.load unpickles the file, which can execute
    # arbitrary code. Only load from a repo you trust, ideally pinned to a
    # specific revision.
    umap_model = joblib.load(umap_path)
    return model, umap_model


def app():
    """Render the embedding page: take a text and a category, embed, and plot.

    Reads two lists from ``st.session_state``: ``'embed_list'`` (texts) and
    ``'cat_list'`` (category labels). The caller must initialise both before
    this page runs; a missing key raises ``KeyError``.

    On each click of "Embed", the function:

    1. appends the current text and category to those lists;
    2. encodes every stored text with the sentence model from
       ``init_models()``;
    3. projects the encodings with the UMAP model;
    4. draws a 3-D scatter (columns 0, 1, 2 of the projection) coloured by
       category.
    """
    # The lists are mutated in place below; the assignments back into
    # session_state only make the update explicit.
    word_to_embed_list = st.session_state['embed_list']
    cat_list = st.session_state['cat_list']

    with st.container():
        col1, col2 = st.columns(2)
        with col1:
            word_to_embed = st.text_input(
                "Please enter your text here and we will embed it for you.",
                value="Woman",
            )
        with col2:
            cat = st.selectbox('Categorie', ('1', '2', '3', '4', '5'))

        if st.button("Embed"):
            with st.spinner("π Embedding your input"):
                model, umap_model = init_models()

                word_to_embed_list.append(word_to_embed)
                st.session_state['embed_list'] = word_to_embed_list
                cat_list.append(cat)
                # Fixed: the key previously had a trailing space
                # ('cat_list '), which wrote a stray session entry instead
                # of updating 'cat_list'.
                st.session_state['cat_list'] = cat_list

                # Every stored text is re-encoded and re-projected on each click.
                examples_embeddings = model.encode(word_to_embed_list)
                examples_umap = umap_model.transform(examples_embeddings)

            with st.spinner("π create visualisation"):
                # Index 0 of every series is skipped. It is presumably a
                # placeholder seeded by whoever initialises session_state;
                # TODO confirm.
                fig = px.scatter_3d(
                    examples_umap[1:], x=0, y=1, z=2,
                    color=cat_list[1:],
                    opacity=.7,
                    hover_data=[word_to_embed_list[1:]],
                )
                fig.update_scenes(
                    xaxis_visible=False,
                    yaxis_visible=False,
                    zaxis_visible=False,
                )
                fig.update_traces(marker_size=4)
                st.plotly_chart(fig)