mohitmayank committed on
Commit
384d3d8
1 Parent(s): 2574e04

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A Streamlit application to visualize sentence embeddings
3
+ Author: Mohit Mayank
4
+ Contact: mohitmayank1@gmail.com
5
+ """
6
+
7
+ ## Import
8
+ ## ----------------
9
+ # data
10
+ import pandas as pd
11
+ # model
12
+ from sentence_transformers import SentenceTransformer, util
13
+ # viz
14
+ import streamlit as st
15
+ import plotly.express as px
16
+ # DR
17
+ from sklearn.decomposition import PCA
18
+ from sklearn.manifold import TSNE
19
+
20
+ ## Init
21
+ ## ----------------
22
# set config
# st.set_page_config(layout="wide", page_title="SentenceViz 🕵")
# Page header: title followed by a one-line description of the app.
st.markdown("# SentenceViz")
st.markdown("A Streamlit application to visualize sentence embeddings")
26
+
27
# load the sentence-embedding model (cached so it is not rebuilt on every call)
@st.cache(allow_output_mutation=True)
def load_similarity_model(model_name='all-MiniLM-L6-v2'):
    """Build and return a ``SentenceTransformer`` for *model_name*.

    Args:
        model_name: Name of the sentence-transformers model to load.

    Returns:
        The ``SentenceTransformer`` instance created from ``model_name``.
    """
    return SentenceTransformer(model_name)
32
+
33
@st.cache(allow_output_mutation=True)
def perform_embedding(df, text_col_name):
    """Encode one text column of *df* with the module-level ``model``.

    Args:
        df: DataFrame holding the sentences.
        text_col_name: Name of the column in ``df`` whose values are encoded.

    Returns:
        Whatever ``model.encode`` returns for the column's values.

    Note:
        ``model`` is read from module scope, so it must be assigned before
        this function is called.
    """
    # Hand the encoder a plain list rather than a pandas Series: ``encode``
    # is documented for a string or a list of strings, and a Series is only
    # safe to index by position when its index happens to be 0..n-1.
    sentences = df[text_col_name].tolist()
    return model.encode(sentences)
37
+
38
# global vars
# Placeholders that the sidebar sections below fill in once the user has
# uploaded a file (df) and chosen a text column (model, embeddings).
df = model = embeddings = None
42
+
43
+ ## Design Sidebar
44
+ ## -----------------
45
## Data
# Sidebar section that lets the user upload the CSV to analyse.
st.sidebar.markdown("## Data")
uploaded_file = st.sidebar.file_uploader("Upload a CSV file with sentences (we remove NaN)")
if uploaded_file is not None:
    # Placeholder whose text is overwritten as the read progresses.
    progress = st.empty()
    progress.text("Reading file...")
    df = pd.read_csv(uploaded_file)
    # Drop every row containing a NaN, then renumber the rows from 0.
    df = df.dropna()
    df = df.reset_index(drop=True)
    progress.text(f"Reading file...Done! Size: {df.shape[0]}")
53
+
54
## Embedding
# Sidebar section that picks the model and the text column, then computes
# the sentence embeddings for that column.
st.sidebar.markdown("## Embedding")
supported_models = [
    'all-MiniLM-L6-v2',
    'paraphrase-albert-small-v2',
    'paraphrase-MiniLM-L3-v2',
    'all-distilroberta-v1',
    'all-mpnet-base-v2',
]
selected_model_option = st.sidebar.selectbox("Select Model:", supported_models)
text_col_name = st.sidebar.text_input("Text column to embed")
if text_col_name and df is not None:
    print("text_col_name -->", text_col_name)
    if text_col_name not in df.columns:
        # A mistyped column name would otherwise surface as a raw KeyError.
        st.sidebar.error(f"Column '{text_col_name}' not found in the uploaded file.")
    else:
        progress = st.empty()
        progress.text("Creating embedding...")
        model = load_similarity_model(selected_model_option)
        # Embed the untouched text first: wrapping and inserting '<br>' is
        # purely a display concern and must not leak into the model input.
        embeddings = perform_embedding(df, text_col_name)
        # Only now reformat the column so long sentences wrap in the plot's
        # hover box (30 characters per line, HTML line breaks).
        df[text_col_name] = df[text_col_name].str.wrap(30)
        df[text_col_name] = df[text_col_name].apply(lambda x: x.replace('\n', '<br>'))
        progress.text("Creating embedding...Done!")
68
+
69
## Visualization
# Sidebar section with the plotting options.
st.sidebar.markdown("## Visualization")
dr_algo = st.sidebar.selectbox("Dimensionality Reduction Algorithm", ('PCA', 't-SNE'))
color_col = st.sidebar.text_input("Color using this col")
# An empty or whitespace-only entry means "do not colour the points".
color_col = color_col if color_col.strip() else None
75
+
76
# Reduce the embeddings to 2-D and draw them as a scatter plot when the
# user presses the button.
if st.sidebar.button('Plot!'):
    if df is None or embeddings is None:
        # Nothing to plot yet: df / embeddings are still the None placeholders.
        st.warning("Upload a CSV file and enter the text column to embed before plotting.")
    elif len(embeddings) < 2:
        st.warning("At least two rows are needed to plot the embeddings.")
    else:
        # get the embeddings and perform DR
        if dr_algo == 'PCA':
            reducer = PCA(n_components=2)
        else:
            # scikit-learn requires perplexity < n_samples; keep the default
            # of 30 for large inputs and shrink it for small ones.
            perplexity = min(30.0, float(len(embeddings) - 1))
            reducer = TSNE(n_components=2, perplexity=perplexity)
        reduced_embeddings = reducer.fit_transform(embeddings)

        # modify the df
        df['viz_embeddings_x'] = reduced_embeddings[:, 0]
        df['viz_embeddings_y'] = reduced_embeddings[:, 1]

        # An unknown colour column would make plotly raise; fall back to an
        # uncoloured plot and tell the user why.
        plot_color = color_col
        if plot_color is not None and plot_color not in df.columns:
            st.warning(f"Column '{plot_color}' not found; plotting without colour.")
            plot_color = None

        # plot the data
        fig = px.scatter(
            df,
            x='viz_embeddings_x',
            y='viz_embeddings_y',
            title=f'"{dr_algo}" on {df.shape[0]} "{selected_model_option}" embeddings',
            color=plot_color,
            hover_data=[text_col_name],
        )
        # The reduced axes carry no interpretable units, so hide both.
        hidden_axis = {'visible': False, 'showticklabels': False}
        fig.update_layout(xaxis=hidden_axis, yaxis=hidden_axis)
        fig.update_traces(
            marker=dict(size=10, opacity=0.7, line=dict(width=1, color='DarkSlateGrey')),
            selector=dict(mode='markers'),
        )
        st.plotly_chart(fig, use_container_width=True)