rianders commited on
Commit
4a49186
1 Parent(s): 4bb5140

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import BertModel, BertTokenizer
3
+ import torch
4
+ from sklearn.decomposition import PCA
5
+ import plotly.graph_objs as go
6
+ import numpy as np
7
+
8
+ # BERT embeddings function
9
+ def get_bert_embeddings(words):
10
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
11
+ model = BertModel.from_pretrained('bert-base-uncased')
12
+ embeddings = []
13
+
14
+ # Extract embeddings
15
+ for word in words:
16
+ inputs = tokenizer(word, return_tensors='pt')
17
+ outputs = model(**inputs)
18
+ embeddings.append(outputs.last_hidden_state[0][0].detach().numpy())
19
+
20
+ # Reduce dimensions to 3 using PCA
21
+ if len(embeddings) > 0:
22
+ pca = PCA(n_components=3)
23
+ reduced_embeddings = pca.fit_transform(np.array(embeddings))
24
+ return reduced_embeddings
25
+ return []
26
+
27
+ # Plotly plotting function
28
+ def plot_interactive_bert_embeddings(embeddings, words):
29
+ if len(words) < 4:
30
+ st.error("Please provide at least 4 words/phrases for effective visualization.")
31
+ return None
32
+
33
+ data = []
34
+ for i, word in enumerate(words):
35
+ trace = go.Scatter3d(
36
+ x=[embeddings[i][0]],
37
+ y=[embeddings[i][1]],
38
+ z=[embeddings[i][2]],
39
+ mode='markers+text',
40
+ text=[word],
41
+ name=word
42
+ )
43
+ data.append(trace)
44
+
45
+ layout = go.Layout(
46
+ title='3D Scatter Plot of BERT Embeddings',
47
+ scene=dict(
48
+ xaxis=dict(title='PCA Component 1'),
49
+ yaxis=dict(title='PCA Component 2'),
50
+ zaxis=dict(title='PCA Component 3')
51
+ ),
52
+ autosize=False,
53
+ width=800,
54
+ height=600
55
+ )
56
+
57
+ fig = go.Figure(data=data, layout=layout)
58
+ return fig
59
+
60
+ def main():
61
+ st.title("BERT Embeddings Visualization")
62
+
63
+ # Initialize or get existing words list from the session state
64
+ if 'words' not in st.session_state:
65
+ st.session_state.words = []
66
+
67
+ # Text input for new words
68
+ new_words_input = st.text_input("Enter a new word/phrase:")
69
+
70
+ # Button to add new words
71
+ if st.button("Add Word/Phrase"):
72
+ if new_words_input:
73
+ st.session_state.words.append(new_words_input)
74
+ st.success(f"Added: {new_words_input}")
75
+
76
+ # Display current list of words
77
+ if st.session_state.words:
78
+ st.write("Current list of words/phrases:", ', '.join(st.session_state.words))
79
+
80
+ # Generate embeddings and plot
81
+ if st.button("Generate Embeddings"):
82
+ with st.spinner('Generating embeddings...'):
83
+ embeddings = get_bert_embeddings(st.session_state.words)
84
+ fig = plot_interactive_bert_embeddings(embeddings, st.session_state.words)
85
+ if fig is not None:
86
+ st.plotly_chart(fig, use_container_width=True)
87
+
88
+ # Reset button
89
+ if st.button("Reset"):
90
+ st.session_state.words = []
91
+
92
+ if __name__ == "__main__":
93
+ main()