import streamlit as st
import pandas as pd
import numpy as np
import torch
from transformers import BertTokenizer, BertModel, RobertaTokenizer, RobertaModel
from sklearn.decomposition import PCA
import plotly.graph_objs as go
from database_utils import init_db, save_embeddings_to_db, get_all_embeddings, clear_all_entries, fetch_data_as_csv
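
# The helpers imported from database_utils live in a separate module that is
# not shown in this file. From the way they are called below, the expected
# contract is roughly (a sketch of the assumed interface, not the actual
# implementation):
#
#   init_db()                             -> create the backing store if needed
#   save_embeddings_to_db(phrase, vector) -> persist one phrase and its 1-D
#                                            NumPy embedding (768 floats for
#                                            bert-base-uncased / roberta-base)
#   get_all_embeddings()                  -> return (embeddings, phrases) as
#                                            parallel sequences
#   clear_all_entries()                   -> delete every stored row
#   fetch_data_as_csv()                   -> return the table as CSV suitable
#                                            for st.download_button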


@st.cache_resource
def load_model(model_name):
    if model_name == "BERT":
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertModel.from_pretrained('bert-base-uncased')
    elif model_name == "RoBERTa":
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
        model = RobertaModel.from_pretrained('roberta-base')
    else:
        raise ValueError(f"Unsupported model: {model_name}")
    return tokenizer, model
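

# Each phrase is embedded by running it through the selected transformer and
# mean-pooling the final hidden states over the token dimension (special
# tokens included), giving one fixed-length vector per phrase (768 values for
# the base models loaded above).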
def get_embeddings(phrases, tokenizer, model):
    embeddings = []
    for phrase in phrases:
        inputs = tokenizer(phrase, return_tensors='pt', padding=True, truncation=True)
        # Inference only: no_grad avoids building the autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        mean_embedding = outputs.last_hidden_state.mean(dim=1).numpy()
        embeddings.append(mean_embedding[0])
    return np.array(embeddings)
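

# The high-dimensional embeddings are projected with PCA: down to two
# components when exactly two phrases are stored (2-D scatter) and to three
# components otherwise (3-D scatter). Each point is labelled with its phrase.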
def plot_interactive_embeddings(embeddings, phrases):
    if len(phrases) >= 2:
        pca = PCA(n_components=min(3, len(phrases)))
        reduced_embeddings = pca.fit_transform(embeddings)
        if len(phrases) == 2:
            fig = go.Figure(data=[
                go.Scatter(x=[emb[0]], y=[emb[1]], mode='markers+text', text=[phrase], name=phrase)
                for emb, phrase in zip(reduced_embeddings, phrases)
            ])
            fig.update_layout(title='2D Scatter Plot of Embeddings',
                              xaxis_title='PCA Component 1', yaxis_title='PCA Component 2')
        else:
            fig = go.Figure(data=[
                go.Scatter3d(x=[emb[0]], y=[emb[1]], z=[emb[2]], mode='markers+text', text=[phrase], name=phrase)
                for emb, phrase in zip(reduced_embeddings, phrases)
            ])
            fig.update_layout(title='3D Scatter Plot of Embeddings',
                              scene=dict(xaxis_title='PCA Component 1', yaxis_title='PCA Component 2',
                                         zaxis_title='PCA Component 3'))
        fig.update_layout(autosize=False, width=800, height=600)
        st.plotly_chart(fig, use_container_width=True)
    else:
        st.error("Please add at least one more phrase to visualize.")
def main():
    st.set_page_config(layout="wide")
    st.title("Language Model Embeddings Visualization")
    st.markdown("""
    This application visualizes embeddings of words and phrases from the BERT or RoBERTa language models.
    Explore how different words and phrases relate to each other in the embedding space!
    """)

    # Load the most recently selected model (BERT on the first run).
    model_choice = st.session_state.get('last_model_choice', "BERT")
    tokenizer, model = load_model(model_choice)

    # Initialize the database and session state before the sidebar renders,
    # so its buttons can reference default_phrase and st.session_state.phrases.
    default_phrase = "example"
    if "phrases" not in st.session_state:
        init_db()
        st.session_state.phrases = [default_phrase]
        embedding = get_embeddings([default_phrase], tokenizer, model)[0]
        save_embeddings_to_db(default_phrase, embedding)

    # Sidebar
    with st.sidebar:
        st.header("Controls")
        model_choice = st.selectbox("Choose a model:", ["BERT", "RoBERTa"])
        if model_choice != st.session_state.get('last_model_choice'):
            tokenizer, model = load_model(model_choice)
            st.session_state.last_model_choice = model_choice

        new_phrase = st.text_input("Enter a new word or phrase:", "")
        if st.button("Add Phrase"):
            if new_phrase and new_phrase not in st.session_state.phrases:
                embedding = get_embeddings([new_phrase], tokenizer, model)[0]
                save_embeddings_to_db(new_phrase, embedding)
                st.session_state.phrases.append(new_phrase)
                st.experimental_rerun()

        uploaded_file = st.file_uploader("Upload CSV file", type="csv")
        if uploaded_file is not None:
            df = pd.read_csv(uploaded_file)
            phrase_column = next((col for col in ['phrase', 'Phrase'] if col in df.columns), None)
            if phrase_column:
                new_phrases = df[phrase_column].dropna().unique().tolist()
                added = 0
                for phrase in new_phrases:
                    if phrase and phrase not in st.session_state.phrases:
                        embedding = get_embeddings([phrase], tokenizer, model)[0]
                        save_embeddings_to_db(phrase, embedding)
                        st.session_state.phrases.append(phrase)
                        added += 1
                # Only rerun when something was actually added; otherwise the
                # still-uploaded file would trigger an endless rerun loop.
                if added:
                    st.success(f"Successfully imported {added} new phrases.")
                    st.experimental_rerun()
            else:
                st.error("The CSV file must contain a 'phrase' or 'Phrase' column.")
if st.button("Clear All Entries"):
clear_all_entries()
st.session_state.phrases = [default_phrase]
embedding = get_embeddings([default_phrase], tokenizer, model)[0]
save_embeddings_to_db(default_phrase, embedding)
st.experimental_rerun()
if st.button("Download Database as CSV"):
csv = fetch_data_as_csv()
st.download_button(label="Download CSV", data=csv, file_name='embeddings.csv', mime='text/csv')

    # Main area
st.subheader(f"Current phrases ({model_choice}):")
st.write(", ".join(st.session_state.phrases))
embeddings, phrases = get_all_embeddings()
if len(embeddings) > 0:
embeddings = np.array(embeddings)
plot_interactive_embeddings(embeddings, phrases)
else:
st.info("Add phrases using the sidebar to visualize their embeddings.")
if __name__ == "__main__":
main() |
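

# To try this locally (assuming the file is saved as app.py and the local
# database_utils module is importable), something like:
#     pip install streamlit pandas numpy torch transformers scikit-learn plotly
#     streamlit run app.py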