Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| from transformers import BertTokenizer, BertModel, RobertaTokenizer, RobertaModel | |
| from sklearn.decomposition import PCA | |
| import plotly.graph_objs as go | |
| import numpy as np | |
| from database_utils import init_db, save_embeddings_to_db, get_all_embeddings, clear_all_entries, fetch_data_as_csv | |
def load_model(model_name):
    """Load the pretrained tokenizer and encoder for the named model.

    Args:
        model_name: Either "BERT" (loads 'bert-base-uncased') or
            "RoBERTa" (loads 'roberta-base').

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Reject unknown names up front so the loading path below stays linear.
    if model_name not in ("BERT", "RoBERTa"):
        raise ValueError(f"Unsupported model: {model_name}")

    if model_name == "BERT":
        tokenizer_cls, model_cls, checkpoint = BertTokenizer, BertModel, 'bert-base-uncased'
    else:
        tokenizer_cls, model_cls, checkpoint = RobertaTokenizer, RobertaModel, 'roberta-base'

    # Tokenizer first, then the model, matching the original load order.
    loaded_tokenizer = tokenizer_cls.from_pretrained(checkpoint)
    loaded_model = model_cls.from_pretrained(checkpoint)
    return loaded_tokenizer, loaded_model
def get_embeddings(phrases, tokenizer, model):
    """Compute one mean-pooled embedding vector per phrase.

    Each phrase is tokenized and encoded on its own; the encoder output's
    ``last_hidden_state`` is averaged over dim 1 to give a single vector.

    Args:
        phrases: Iterable of phrases to embed.
        tokenizer: Callable invoked as
            ``tokenizer(phrase, return_tensors='pt', padding=True, truncation=True)``
            whose result is unpacked as keyword arguments to ``model``.
        model: Callable whose return value exposes ``last_hidden_state``.

    Returns:
        A NumPy array with one row per phrase (an empty array when
        ``phrases`` is empty).
    """
    # Function-scope import: torch is not imported at the top of this file,
    # but the tokenizer is already asked for PyTorch tensors ('pt').
    import torch

    embeddings = []
    # Inference only: with autograd disabled no gradient graph is built for
    # each forward pass, so the result can go straight to NumPy.
    with torch.no_grad():
        for phrase in phrases:
            inputs = tokenizer(phrase, return_tensors='pt', padding=True, truncation=True)
            outputs = model(**inputs)
            mean_embedding = outputs.last_hidden_state.mean(dim=1).numpy()
            # Each call encodes a single phrase, so row 0 is the phrase's vector.
            embeddings.append(mean_embedding[0])
    return np.array(embeddings)
def plot_interactive_embeddings(embeddings, phrases):
    """Render a PCA projection of the embeddings as a Plotly scatter chart.

    With exactly two phrases the embeddings are reduced to two components
    and drawn as a 2D scatter; with three or more they are reduced to three
    components and drawn as a 3D scatter. Each phrase becomes its own trace.
    With fewer than two phrases an error message is shown instead.

    Args:
        embeddings: Array of embedding vectors passed to ``PCA.fit_transform``.
        phrases: Labels paired with the rows of ``embeddings``.
    """
    phrase_count = len(phrases)
    if phrase_count < 2:
        st.error("Please add at least one more phrase to visualize.")
        return

    projected = PCA(n_components=min(3, phrase_count)).fit_transform(embeddings)

    traces = []
    if phrase_count == 2:
        for point, label in zip(projected, phrases):
            traces.append(
                go.Scatter(x=[point[0]], y=[point[1]], mode='markers+text', text=[label], name=label)
            )
        fig = go.Figure(data=traces)
        fig.update_layout(
            title='2D Scatter Plot of Embeddings',
            xaxis_title='PCA Component 1',
            yaxis_title='PCA Component 2',
        )
    else:
        for point, label in zip(projected, phrases):
            traces.append(
                go.Scatter3d(
                    x=[point[0]], y=[point[1]], z=[point[2]],
                    mode='markers+text', text=[label], name=label,
                )
            )
        fig = go.Figure(data=traces)
        axis_titles = dict(
            xaxis_title='PCA Component 1',
            yaxis_title='PCA Component 2',
            zaxis_title='PCA Component 3',
        )
        fig.update_layout(title='3D Scatter Plot of Embeddings', scene=axis_titles)

    fig.update_layout(autosize=False, width=800, height=600)
    st.plotly_chart(fig, use_container_width=True)
def _rerun():
    """Restart the script run, preferring ``st.rerun`` when it is available."""
    # The app was written against st.experimental_rerun; newer Streamlit
    # releases expose the same behavior as st.rerun, so use it when present.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _embed_and_save(phrase, tokenizer, model):
    """Embed a single phrase and persist the vector to the database."""
    embedding = get_embeddings([phrase], tokenizer, model)[0]
    save_embeddings_to_db(phrase, embedding)


def main():
    """Run the Streamlit app: sidebar controls plus the embedding plot.

    The sidebar lets the user pick a model, add a phrase, import phrases
    from a CSV file (column 'phrase' or 'Phrase'), clear all entries, and
    download the database as CSV. The main area lists the session's phrases
    and plots every embedding returned by ``get_all_embeddings()``.
    """
    st.set_page_config(layout="wide")
    st.title("Language Model Embeddings Visualization")
    st.markdown("""
    This application visualizes embeddings of words and phrases from BERT or RoBERTa language models.
    Explore how different words and phrases relate to each other in the embedding space!
    """)

    # Defined before any control that uses it; "Clear All Entries" reads it.
    default_phrase = "example"

    # Sidebar, part 1: the model choice is needed before anything is embedded.
    with st.sidebar:
        st.header("Controls")
        model_choice = st.selectbox("Choose a model:", ["BERT", "RoBERTa"])

    # Keep the loaded pair in session state so a script rerun only reloads
    # the model when the selection actually changed.
    if (model_choice != st.session_state.get('last_model_choice')
            or 'loaded_model' not in st.session_state):
        st.session_state.loaded_model = load_model(model_choice)
        st.session_state.last_model_choice = model_choice
    tokenizer, model = st.session_state.loaded_model

    # First run of a session: set up the database and seed the default
    # phrase before any sidebar control touches st.session_state.phrases.
    if "phrases" not in st.session_state:
        init_db()
        _embed_and_save(default_phrase, tokenizer, model)
        st.session_state.phrases = [default_phrase]

    # Sidebar, part 2: controls that read or modify the phrase list.
    with st.sidebar:
        new_phrase = st.text_input("Enter a new word or phrase:", "")
        if st.button("Add Phrase"):
            if new_phrase and new_phrase not in st.session_state.phrases:
                _embed_and_save(new_phrase, tokenizer, model)
                st.session_state.phrases.append(new_phrase)
                _rerun()

        uploaded_file = st.file_uploader("Upload CSV file", type="csv")
        if uploaded_file is not None:
            # The uploader keeps returning the file on later runs; rewind in
            # case the same buffer was already consumed by an earlier run.
            uploaded_file.seek(0)
            df = pd.read_csv(uploaded_file)
            phrase_column = next((col for col in ['phrase', 'Phrase'] if col in df.columns), None)
            if phrase_column:
                # Cast to str so numeric cells can be tokenized and joined.
                candidates = df[phrase_column].dropna().astype(str).unique().tolist()
                added_count = 0
                for phrase in candidates:
                    if phrase and phrase not in st.session_state.phrases:
                        _embed_and_save(phrase, tokenizer, model)
                        st.session_state.phrases.append(phrase)
                        added_count += 1
                # No forced rerun here: the file stays in the uploader, so an
                # unconditional rerun would re-enter this branch forever. The
                # main area below runs after this point and already sees the
                # newly added phrases.
                if added_count:
                    st.success(f"Successfully imported {added_count} new phrases.")
            else:
                st.error("The CSV file must contain a 'phrase' or 'Phrase' column.")

        if st.button("Clear All Entries"):
            clear_all_entries()
            _embed_and_save(default_phrase, tokenizer, model)
            st.session_state.phrases = [default_phrase]
            _rerun()

        if st.button("Download Database as CSV"):
            csv = fetch_data_as_csv()
            st.download_button(label="Download CSV", data=csv, file_name='embeddings.csv', mime='text/csv')

    # Main area
    st.subheader(f"Current phrases ({model_choice}):")
    st.write(", ".join(st.session_state.phrases))

    # NOTE(review): embeddings are saved without recording which model
    # produced them, so vectors from both models may be plotted together
    # after switching models; confirm whether that is intended.
    embeddings, phrases = get_all_embeddings()
    if len(embeddings) > 0:
        plot_interactive_embeddings(np.array(embeddings), phrases)
    else:
        st.info("Add phrases using the sidebar to visualize their embeddings.")


# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()