rianders commited on
Commit
e2d7fb5
1 Parent(s): dd4fed4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from transformers import AutoModel, BertTokenizer
3
  from sklearn.decomposition import PCA
4
  import plotly.graph_objs as go
5
  import numpy as np
@@ -7,7 +7,7 @@ from database_utils import init_db, save_embeddings_to_db, get_all_embeddings, c
7
 
8
  # Initialize BERT model and tokenizer
9
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
10
- model = AutoModel.from_pretrained('bert-base-uncased')
11
 
12
  def get_bert_embeddings(words):
13
  embeddings = []
@@ -83,8 +83,8 @@ def main():
83
  st.download_button(label="Download CSV", data=csv, file_name='embeddings.csv', mime='text/csv')
84
 
85
  embeddings, words = get_all_embeddings()
86
- embeddings = np.array(embeddings)
87
- if embeddings.size > 0:
88
  plot_interactive_bert_embeddings(embeddings, words)
89
 
90
  if __name__ == "__main__":
 
1
  import streamlit as st
2
+ from transformers import BertModel, BertTokenizer
3
  from sklearn.decomposition import PCA
4
  import plotly.graph_objs as go
5
  import numpy as np
 
7
 
8
  # Initialize BERT model and tokenizer
9
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
10
+ model = BertModel.from_pretrained('bert-base-uncased')
11
 
12
  def get_bert_embeddings(words):
13
  embeddings = []
 
83
  st.download_button(label="Download CSV", data=csv, file_name='embeddings.csv', mime='text/csv')
84
 
85
  embeddings, words = get_all_embeddings()
86
+ if len(embeddings) > 0:
87
+ embeddings = np.array(embeddings)
88
  plot_interactive_bert_embeddings(embeddings, words)
89
 
90
  if __name__ == "__main__":