import streamlit as st
from transformers import BertModel, BertTokenizer
import torch
from sklearn.decomposition import PCA
import plotly.graph_objs as go
import numpy as np

# Hugging Face checkpoint used for both the tokenizer and the encoder.
MODEL_NAME = 'bert-base-uncased'
# Minimum number of words/phrases required before a plot is produced.
MIN_WORDS = 4
# Number of PCA components kept (one per axis of the 3D scatter plot).
N_COMPONENTS = 3
# Error shown when the user asks for a plot with too few words/phrases.
_TOO_FEW_WORDS_MSG = (
    f"Please provide at least {MIN_WORDS} words/phrases for effective visualization."
)


def _cache_resource(func):
    """Wrap *func* in Streamlit's resource cache when this Streamlit has one.

    Streamlit re-executes the whole script on every widget interaction, so
    without a cache the BERT weights would be reloaded on every click.
    ``st.cache_resource`` only exists in newer Streamlit releases; on older
    ones *func* is returned unwrapped (correct, just slower).
    """
    cacher = getattr(st, 'cache_resource', None)
    return cacher(func) if cacher is not None else func


@_cache_resource
def _load_bert(model_name=MODEL_NAME):
    """Load and return the ``(tokenizer, model)`` pair for *model_name*."""
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertModel.from_pretrained(model_name)
    model.eval()  # inference only: make sure dropout is disabled
    return tokenizer, model


# BERT embeddings function
def get_bert_embeddings(words):
    """Embed each word/phrase with BERT and project the vectors to 3D via PCA.

    Every entry of *words* is encoded with ``bert-base-uncased`` and
    represented by the final-layer hidden state of its ``[CLS]`` token; the
    resulting vectors are then reduced to three principal components.

    Args:
        words: Sequence of strings to embed.

    Returns:
        A NumPy array of shape ``(len(words), 3)``, or an empty list when
        *words* is empty.

    Raises:
        ValueError: If fewer than three words are given -- PCA cannot extract
            three components from fewer samples.
    """
    words = list(words)
    if not words:
        return []
    if len(words) < N_COMPONENTS:
        raise ValueError(
            f"At least {N_COMPONENTS} words/phrases are needed for a "
            f"{N_COMPONENTS}-component PCA, got {len(words)}."
        )

    tokenizer, model = _load_bert()
    # Encode everything in one padded batch; the attention mask returned by
    # the tokenizer keeps padding tokens out of the [CLS] representation.
    # Truncation guards against phrases longer than BERT's position limit.
    inputs = tokenizer(words, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():  # no autograd graph needed for inference
        outputs = model(**inputs)

    # Use the [CLS] token's embedding (position 0 of every sequence).
    cls_embeddings = outputs.last_hidden_state[:, 0, :].numpy()

    pca = PCA(n_components=N_COMPONENTS)
    return pca.fit_transform(np.asarray(cls_embeddings))


# Plotly plotting function
def plot_interactive_bert_embeddings(embeddings, words):
    """Build a 3D Plotly scatter with one labelled trace per word/phrase.

    Args:
        embeddings: Indexable of 3-component points, one per entry of *words*.
        words: Labels for the points, in the same order as *embeddings*.

    Returns:
        A ``plotly.graph_objs.Figure``, or ``None`` (after showing an error
        in the Streamlit page) when fewer than ``MIN_WORDS`` words are given.

    Raises:
        ValueError: If *embeddings* and *words* differ in length.
    """
    if len(words) < MIN_WORDS:
        st.error(_TOO_FEW_WORDS_MSG)
        return None
    if len(embeddings) != len(words):
        raise ValueError(
            f"Got {len(embeddings)} embeddings for {len(words)} words/phrases."
        )

    # One trace per word so each gets its own legend entry and colour.
    data = [
        go.Scatter3d(
            x=[point[0]],
            y=[point[1]],
            z=[point[2]],
            mode='markers+text',
            text=[word],
            name=word,
        )
        for point, word in zip(embeddings, words)
    ]

    layout = go.Layout(
        title='3D Scatter Plot of BERT Embeddings',
        scene=dict(
            xaxis=dict(title='PCA Component 1'),
            yaxis=dict(title='PCA Component 2'),
            zaxis=dict(title='PCA Component 3'),
        ),
        autosize=False,
        width=800,
        height=600,
    )
    return go.Figure(data=data, layout=layout)


def _reset_words():
    """Clear the stored word list.

    Used as the Reset button's ``on_click`` callback: callbacks run before
    the script reruns, so the cleared list is what gets rendered.
    """
    st.session_state.words = []


def main():
    """Render the app: collect words/phrases, then embed and plot them."""
    st.title("BERT Embeddings Visualization")

    # Initialize or get existing words list from the session state
    if 'words' not in st.session_state:
        st.session_state.words = []

    # Text input for new words
    new_words_input = st.text_input("Enter a new word/phrase:")

    # Button to add new words (ignores blank input and exact duplicates)
    if st.button("Add Word/Phrase"):
        new_word = new_words_input.strip()
        if new_word:
            if new_word in st.session_state.words:
                st.warning(f"Already added: {new_word}")
            else:
                st.session_state.words.append(new_word)
                st.success(f"Added: {new_word}")

    # Display current list of words
    if st.session_state.words:
        st.write("Current list of words/phrases:", ', '.join(st.session_state.words))

    # Generate embeddings and plot
    if st.button("Generate Embeddings"):
        words = st.session_state.words
        # Validate up front: avoids loading BERT for nothing and avoids the
        # PCA failure that fewer than three samples would trigger.
        if len(words) < MIN_WORDS:
            st.error(_TOO_FEW_WORDS_MSG)
        else:
            with st.spinner('Generating embeddings...'):
                embeddings = get_bert_embeddings(words)
            fig = plot_interactive_bert_embeddings(embeddings, words)
            if fig is not None:
                st.plotly_chart(fig, use_container_width=True)

    # Reset button (callback so the cleared list renders immediately)
    st.button("Reset", on_click=_reset_words)


if __name__ == "__main__":
    main()