# Scraped page banner (not part of the program): "Spaces: Sleeping"
import streamlit as st | |
from transformers import BertModel, BertTokenizer | |
import torch | |
from sklearn.decomposition import PCA | |
import plotly.graph_objs as go | |
import numpy as np | |
# BERT embeddings function
def get_bert_embeddings(words):
    """Embed each word/phrase with BERT and project the embeddings to 3-D.

    Every entry of ``words`` is tokenized on its own and passed through
    ``bert-base-uncased``; the hidden state at position 0 (the [CLS] token)
    is used as that entry's embedding. The stacked embeddings are then
    reduced with PCA.

    Args:
        words: Iterable of words or phrases to embed.

    Returns:
        An array with one row per entry and three columns, or an empty list
        when ``words`` is empty. With three or more entries the columns are
        the first three PCA components. With fewer entries PCA cannot yield
        three components, so the components that cannot be computed are
        left as zeros instead of raising.
    """
    words = list(words)
    # Nothing to embed: skip loading the tokenizer and model altogether.
    if not words:
        return []

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')

    embeddings = []
    # Inference only: without autograd no graph is recorded per forward pass.
    with torch.no_grad():
        for word in words:
            inputs = tokenizer(word, return_tensors='pt')
            outputs = model(**inputs)
            # Use the [CLS] token's embedding
            embeddings.append(outputs.last_hidden_state[0][0].numpy())

    matrix = np.array(embeddings)
    n_samples = matrix.shape[0]
    if n_samples >= 3:
        return PCA(n_components=3).fit_transform(matrix)

    # PCA needs n_components <= n_samples, so 3 components are impossible
    # here. Keep the three-column shape callers index into and zero-fill.
    reduced = np.zeros((n_samples, 3), dtype=matrix.dtype)
    if n_samples == 2:
        reduced[:, :2] = PCA(n_components=2).fit_transform(matrix)
    # A single point has no variance; its centred projection is the origin.
    return reduced
# Plotly plotting function
def plot_interactive_bert_embeddings(embeddings, words):
    """Build a 3-D Plotly scatter figure with one labelled marker per word.

    Args:
        embeddings: Indexable collection; item ``i`` holds the three
            coordinates plotted for ``words[i]``.
        words: Words/phrases to plot. Each one gets its own trace, named
            and labelled with the word itself.

    Returns:
        The assembled ``go.Figure``, or ``None`` (after reporting a
        Streamlit error) when fewer than 4 words are supplied.
    """
    if not len(words) >= 4:
        st.error("Please provide at least 4 words/phrases for effective visualization.")
        return None

    def build_trace(index, label):
        # One single-point trace per word so each gets its own legend entry.
        point = embeddings[index]
        return go.Scatter3d(
            x=[point[0]],
            y=[point[1]],
            z=[point[2]],
            mode='markers+text',
            text=[label],
            name=label,
        )

    traces = [build_trace(index, label) for index, label in enumerate(words)]

    scene = {
        'xaxis': {'title': 'PCA Component 1'},
        'yaxis': {'title': 'PCA Component 2'},
        'zaxis': {'title': 'PCA Component 3'},
    }
    figure_layout = go.Layout(
        title='3D Scatter Plot of BERT Embeddings',
        scene=scene,
        autosize=False,
        width=800,
        height=600,
    )
    return go.Figure(data=traces, layout=figure_layout)
def main():
    """Run the Streamlit app for collecting words and plotting embeddings.

    The word list lives in ``st.session_state.words`` so it survives
    Streamlit's script reruns. The page offers a text input with an "Add"
    button, a "Generate Embeddings" button that computes and plots the
    embeddings, and a "Reset" button that empties the list.
    """
    # Same threshold that plot_interactive_bert_embeddings enforces.
    min_words = 4

    def reset_words():
        # Runs as a button callback, i.e. before the rerun renders the page,
        # so the cleared list is what gets displayed.
        st.session_state.words = []

    st.title("BERT Embeddings Visualization")

    # Initialize or get existing words list from the session state
    if 'words' not in st.session_state:
        st.session_state.words = []

    # Text input for new words
    new_words_input = st.text_input("Enter a new word/phrase:")

    # Button to add new words
    if st.button("Add Word/Phrase"):
        # Ignore empty or whitespace-only submissions.
        new_word = new_words_input.strip()
        if new_word:
            st.session_state.words.append(new_word)
            st.success(f"Added: {new_word}")

    # Display current list of words
    if st.session_state.words:
        st.write("Current list of words/phrases:", ', '.join(st.session_state.words))

    # Generate embeddings and plot
    if st.button("Generate Embeddings"):
        words = st.session_state.words
        if len(words) < min_words:
            # Validate before the expensive model load; with fewer than
            # three words the 3-component PCA would raise instead.
            st.error("Please provide at least 4 words/phrases for effective visualization.")
        else:
            with st.spinner('Generating embeddings...'):
                embeddings = get_bert_embeddings(words)
                fig = plot_interactive_bert_embeddings(embeddings, words)
            if fig is not None:
                st.plotly_chart(fig, use_container_width=True)

    # Reset button
    st.button("Reset", on_click=reset_words)


if __name__ == "__main__":
    main()