rianders committed on
Commit
91c7a65
·
verified ·
1 Parent(s): 78f2519

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -37
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import streamlit as st
 
2
  from transformers import BertModel, BertTokenizer, RobertaModel, RobertaTokenizer
3
  from sklearn.decomposition import PCA
4
  import plotly.graph_objs as go
@@ -31,38 +32,17 @@ def plot_interactive_embeddings(embeddings, words):
31
 
32
  if len(words) == 2:
33
  fig = go.Figure(data=[
34
- go.Scatter(
35
- x=[emb[0]],
36
- y=[emb[1]],
37
- mode='markers+text',
38
- text=[word],
39
- name=word
40
- ) for emb, word in zip(reduced_embeddings, words)
41
  ])
42
- fig.update_layout(
43
- title='2D Scatter Plot of Embeddings',
44
- xaxis_title='PCA Component 1',
45
- yaxis_title='PCA Component 2'
46
- )
47
  else:
48
  fig = go.Figure(data=[
49
- go.Scatter3d(
50
- x=[emb[0]],
51
- y=[emb[1]],
52
- z=[emb[2]],
53
- mode='markers+text',
54
- text=[word],
55
- name=word
56
- ) for emb, word in zip(reduced_embeddings, words)
57
  ])
58
- fig.update_layout(
59
- title='3D Scatter Plot of Embeddings',
60
- scene=dict(
61
- xaxis_title='PCA Component 1',
62
- yaxis_title='PCA Component 2',
63
- zaxis_title='PCA Component 3'
64
- )
65
- )
66
 
67
  fig.update_layout(autosize=False, width=800, height=600)
68
  st.plotly_chart(fig, use_container_width=True)
@@ -72,33 +52,50 @@ def plot_interactive_embeddings(embeddings, words):
72
  def main():
73
  st.title("Language Model Embeddings Visualization")
74
 
 
 
 
 
 
 
 
 
 
 
75
  model_choice = st.selectbox("Choose a model:", ["BERT", "RoBERTa"])
76
  tokenizer, model = load_model(model_choice)
77
 
78
  default_word = "example"
79
- if "words" not in st.session_state or "model" not in st.session_state:
80
  st.session_state.words = [default_word]
81
- st.session_state.model = model_choice
82
  init_db()
83
  embedding = get_embeddings([default_word], tokenizer, model)[0]
84
  save_embeddings_to_db(default_word, embedding)
85
- elif st.session_state.model != model_choice:
86
- st.session_state.words = [default_word]
87
- st.session_state.model = model_choice
88
- clear_all_entries()
89
- embedding = get_embeddings([default_word], tokenizer, model)[0]
90
- save_embeddings_to_db(default_word, embedding)
91
 
92
  st.write(f"Current words ({model_choice}):", ", ".join(st.session_state.words))
93
 
94
  new_word = st.text_input("Enter a new word or phrase:", "")
95
  if st.button("Add Word/Phrase"):
96
- if new_word:
97
  embedding = get_embeddings([new_word], tokenizer, model)[0]
98
  save_embeddings_to_db(new_word, embedding)
99
  st.session_state.words.append(new_word)
100
  st.experimental_rerun()
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  if st.button("Clear All Entries"):
103
  clear_all_entries()
104
  st.session_state.words = [default_word]
 
1
  import streamlit as st
2
+ import pandas as pd
3
  from transformers import BertModel, BertTokenizer, RobertaModel, RobertaTokenizer
4
  from sklearn.decomposition import PCA
5
  import plotly.graph_objs as go
 
32
 
33
  if len(words) == 2:
34
  fig = go.Figure(data=[
35
+ go.Scatter(x=[emb[0]], y=[emb[1]], mode='markers+text', text=[word], name=word)
36
+ for emb, word in zip(reduced_embeddings, words)
 
 
 
 
 
37
  ])
38
+ fig.update_layout(title='2D Scatter Plot of Embeddings', xaxis_title='PCA Component 1', yaxis_title='PCA Component 2')
 
 
 
 
39
  else:
40
  fig = go.Figure(data=[
41
+ go.Scatter3d(x=[emb[0]], y=[emb[1]], z=[emb[2]], mode='markers+text', text=[word], name=word)
42
+ for emb, word in zip(reduced_embeddings, words)
 
 
 
 
 
 
43
  ])
44
+ fig.update_layout(title='3D Scatter Plot of Embeddings',
45
+ scene=dict(xaxis_title='PCA Component 1', yaxis_title='PCA Component 2', zaxis_title='PCA Component 3'))
 
 
 
 
 
 
46
 
47
  fig.update_layout(autosize=False, width=800, height=600)
48
  st.plotly_chart(fig, use_container_width=True)
 
52
  def main():
53
  st.title("Language Model Embeddings Visualization")
54
 
55
+ st.markdown("""
56
+ This application visualizes word embeddings from BERT or RoBERTa language models.
57
+ Here's how to use it:
58
+ 1. Choose a model (BERT or RoBERTa) from the dropdown menu.
59
+ 2. Enter words or phrases one at a time, or upload a CSV file with a 'word' column.
60
+ 3. View the 2D or 3D plot of the embeddings.
61
+ 4. Download the current database as a CSV file for later use.
62
+ Explore how different words relate to each other in the embedding space!
63
+ """)
64
+
65
  model_choice = st.selectbox("Choose a model:", ["BERT", "RoBERTa"])
66
  tokenizer, model = load_model(model_choice)
67
 
68
  default_word = "example"
69
+ if "words" not in st.session_state:
70
  st.session_state.words = [default_word]
 
71
  init_db()
72
  embedding = get_embeddings([default_word], tokenizer, model)[0]
73
  save_embeddings_to_db(default_word, embedding)
 
 
 
 
 
 
74
 
75
  st.write(f"Current words ({model_choice}):", ", ".join(st.session_state.words))
76
 
77
  new_word = st.text_input("Enter a new word or phrase:", "")
78
  if st.button("Add Word/Phrase"):
79
+ if new_word and new_word not in st.session_state.words:
80
  embedding = get_embeddings([new_word], tokenizer, model)[0]
81
  save_embeddings_to_db(new_word, embedding)
82
  st.session_state.words.append(new_word)
83
  st.experimental_rerun()
84
 
85
+ uploaded_file = st.file_uploader("Upload CSV file", type="csv")
86
+ if uploaded_file is not None:
87
+ df = pd.read_csv(uploaded_file)
88
+ if 'word' in df.columns:
89
+ new_words = df['word'].tolist()
90
+ for word in new_words:
91
+ if word not in st.session_state.words:
92
+ embedding = get_embeddings([word], tokenizer, model)[0]
93
+ save_embeddings_to_db(word, embedding)
94
+ st.session_state.words.append(word)
95
+ st.experimental_rerun()
96
+ else:
97
+ st.error("The CSV file must contain a 'word' column.")
98
+
99
  if st.button("Clear All Entries"):
100
  clear_all_entries()
101
  st.session_state.words = [default_word]