rianders's picture
Update app.py
95c80a2 verified
raw
history blame
2.9 kB
import streamlit as st
from transformers import BertModel, BertTokenizer
import torch
from sklearn.decomposition import PCA
import plotly.graph_objs as go
import numpy as np
# BERT embeddings function
@st.cache_resource
def _load_bert(model_name='bert-base-uncased'):
    """Load the BERT tokenizer and model once and reuse them across reruns.

    Cached with ``st.cache_resource`` so the large pretrained weights are not
    re-read from disk every time embeddings are generated.
    """
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertModel.from_pretrained(model_name)
    # Disable dropout so repeated runs give identical embeddings.
    model.eval()
    return tokenizer, model


def get_bert_embeddings(words):
    """Embed each word/phrase with BERT and project the vectors to 3-D with PCA.

    Each entry of ``words`` is tokenized on its own and represented by the
    final-layer hidden state of its [CLS] token (position 0). The stacked
    vectors are then reduced to three PCA components.

    Args:
        words: Sequence of strings to embed.

    Returns:
        A NumPy array with one row of 3 coordinates per word, or an empty list
        when ``words`` is empty. PCA cannot yield more components than there are
        samples. With fewer than 3 words, the missing coordinates are therefore
        zero-filled; a single word maps to the origin.
    """
    tokenizer, model = _load_bert()

    cls_vectors = []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for word in words:
            inputs = tokenizer(word, return_tensors='pt')
            outputs = model(**inputs)
            # Use the [CLS] token's embedding (batch 0, token 0).
            cls_vectors.append(outputs.last_hidden_state[0, 0].numpy())

    if not cls_vectors:
        return []

    matrix = np.array(cls_vectors)
    n_samples = matrix.shape[0]
    if n_samples < 2:
        # PCA needs at least two samples; a lone centred point sits at the origin.
        return np.zeros((n_samples, 3))

    # PCA(n_components=3) raises ValueError when there are fewer than 3 samples,
    # so cap the component count and pad the missing axes with zeros.
    n_components = min(3, n_samples)
    reduced = PCA(n_components=n_components).fit_transform(matrix)
    if n_components < 3:
        reduced = np.pad(reduced, ((0, 0), (0, 3 - n_components)))
    return reduced
# Plotly plotting function
def plot_interactive_bert_embeddings(embeddings, words):
    """Return a 3-D Plotly scatter figure with one labelled marker per word.

    Args:
        embeddings: Indexable of points; ``embeddings[i]`` must supply at least
            three coordinates (x, y, z) for ``words[i]``.
        words: Labels, one per point. Each label is used both as the marker
            text and as the trace's legend name.

    Returns:
        A ``go.Figure``, or ``None`` after showing a Streamlit error when
        fewer than 4 words are supplied.
    """
    if len(words) < 4:
        st.error("Please provide at least 4 words/phrases for effective visualization.")
        return None

    def _trace_for(idx, label):
        # One trace per word so every point gets its own legend entry.
        point = embeddings[idx]
        return go.Scatter3d(
            x=[point[0]],
            y=[point[1]],
            z=[point[2]],
            mode='markers+text',
            text=[label],
            name=label,
        )

    scene_axes = dict(
        xaxis=dict(title='PCA Component 1'),
        yaxis=dict(title='PCA Component 2'),
        zaxis=dict(title='PCA Component 3'),
    )

    return go.Figure(
        data=[_trace_for(idx, label) for idx, label in enumerate(words)],
        layout=go.Layout(
            title='3D Scatter Plot of BERT Embeddings',
            scene=scene_axes,
            autosize=False,
            width=800,
            height=600,
        ),
    )
def _reset_words():
    """Clear the accumulated word list (callback for the Reset button)."""
    st.session_state.words = []


def main():
    """Render the app: collect words/phrases, then plot their BERT embeddings.

    The working list is kept in ``st.session_state.words`` so it persists
    between reruns of this script.
    """
    st.title("BERT Embeddings Visualization")

    # Initialize or get existing words list from the session state
    if 'words' not in st.session_state:
        st.session_state.words = []

    # Text input for new words
    new_words_input = st.text_input("Enter a new word/phrase:")

    # Button to add new words; whitespace-only input is ignored.
    if st.button("Add Word/Phrase"):
        phrase = new_words_input.strip()
        if phrase:
            st.session_state.words.append(phrase)
            st.success(f"Added: {phrase}")

    # Display current list of words
    if st.session_state.words:
        st.write("Current list of words/phrases:", ', '.join(st.session_state.words))

    # Generate embeddings and plot
    if st.button("Generate Embeddings"):
        words = st.session_state.words
        # Validate before the expensive model run. The PCA step in
        # get_bert_embeddings can fail on very small inputs, and
        # plot_interactive_bert_embeddings requires at least 4 words anyway.
        if len(words) < 4:
            st.error("Please provide at least 4 words/phrases for effective visualization.")
        else:
            with st.spinner('Generating embeddings...'):
                embeddings = get_bert_embeddings(words)
                fig = plot_interactive_bert_embeddings(embeddings, words)
            if fig is not None:
                st.plotly_chart(fig, use_container_width=True)

    # Reset button. The on_click callback runs before the next rerun renders
    # anything, so the displayed list is cleared immediately rather than one
    # interaction later.
    st.button("Reset", on_click=_reset_words)
# Render the app when this file is executed as a script rather than imported.
if __name__ == "__main__":
    main()