import streamlit as st
import pandas as pd
from transformers import BertTokenizer, BertModel, RobertaTokenizer, RobertaModel
from sklearn.decomposition import PCA
import plotly.graph_objs as go
import numpy as np
import torch
from database_utils import init_db, save_embeddings_to_db, get_all_embeddings, clear_all_entries, fetch_data_as_csv

@st.cache_resource
def load_model(model_name):
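    """Load and cache the tokenizer and model for the chosen architecture (BERT or RoBERTa)."""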
    if model_name == "BERT":
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertModel.from_pretrained('bert-base-uncased')
    elif model_name == "RoBERTa":
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
        model = RobertaModel.from_pretrained('roberta-base')
    else:
        raise ValueError(f"Unsupported model: {model_name}")
    
    return tokenizer, model

def get_embeddings(phrases, tokenizer, model):
    """Return one mean-pooled embedding vector per phrase as a NumPy array."""
    embeddings = []
    for phrase in phrases:
        inputs = tokenizer(phrase, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():  # inference only; no gradients needed
            outputs = model(**inputs)
        # Mean-pool the token embeddings of the last hidden layer
        mean_embedding = outputs.last_hidden_state.mean(dim=1).numpy()
        embeddings.append(mean_embedding[0])
    return np.array(embeddings)

def plot_interactive_embeddings(embeddings, phrases):
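    """Project embeddings to 2 or 3 dimensions with PCA and draw an interactive Plotly scatter."""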
    if len(phrases) >= 2:
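        # PCA components are capped at the number of phrases (and at 3, for plotting)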
        pca = PCA(n_components=min(3, len(phrases)))
        reduced_embeddings = pca.fit_transform(embeddings)
        
        if len(phrases) == 2:
            fig = go.Figure(data=[
                go.Scatter(x=[emb[0]], y=[emb[1]], mode='markers+text', text=[phrase], name=phrase)
                for emb, phrase in zip(reduced_embeddings, phrases)
            ])
            fig.update_layout(title='2D Scatter Plot of Embeddings', xaxis_title='PCA Component 1', yaxis_title='PCA Component 2')
        else:
            fig = go.Figure(data=[
                go.Scatter3d(x=[emb[0]], y=[emb[1]], z=[emb[2]], mode='markers+text', text=[phrase], name=phrase)
                for emb, phrase in zip(reduced_embeddings, phrases)
            ])
            fig.update_layout(title='3D Scatter Plot of Embeddings',
                              scene=dict(xaxis_title='PCA Component 1', yaxis_title='PCA Component 2', zaxis_title='PCA Component 3'))
        
        fig.update_layout(autosize=False, width=800, height=600)
        st.plotly_chart(fig, use_container_width=True)
    else:
        st.error("Please add at least one more phrase to visualize.")

def main():
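    """Build the Streamlit UI: sidebar controls plus the embedding visualization."""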
    st.set_page_config(layout="wide")
    st.title("Language Model Embeddings Visualization")

    st.markdown("""
    This application visualizes embeddings of words and phrases from BERT or RoBERTa language models. 
    Explore how different words and phrases relate to each other in the embedding space!
    """)

    default_phrase = "example"

    # Sidebar
    with st.sidebar:
        st.header("Controls")
        model_choice = st.selectbox("Choose a model:", ["BERT", "RoBERTa"])
        # load_model is cached by st.cache_resource, so calling it on every rerun
        # is cheap and keeps tokenizer/model in sync with the current selection
        tokenizer, model = load_model(model_choice)

        # Initialize the database and session state on the first run, before any
        # button handler below reads st.session_state.phrases
        if "phrases" not in st.session_state:
            init_db()
            st.session_state.phrases = [default_phrase]
            embedding = get_embeddings([default_phrase], tokenizer, model)[0]
            save_embeddings_to_db(default_phrase, embedding)
        
        new_phrase = st.text_input("Enter a new word or phrase:", "")
        if st.button("Add Phrase"):
            if new_phrase and new_phrase not in st.session_state.phrases:
                embedding = get_embeddings([new_phrase], tokenizer, model)[0]
                save_embeddings_to_db(new_phrase, embedding)
                st.session_state.phrases.append(new_phrase)
                st.experimental_rerun()
        
        uploaded_file = st.file_uploader("Upload CSV file", type="csv")
        if uploaded_file is not None:
            df = pd.read_csv(uploaded_file)
            phrase_column = next((col for col in ['phrase', 'Phrase'] if col in df.columns), None)
            if phrase_column:
                new_phrases = df[phrase_column].dropna().unique().tolist()
                added = 0
                for phrase in new_phrases:
                    if phrase and phrase not in st.session_state.phrases:
                        embedding = get_embeddings([phrase], tokenizer, model)[0]
                        save_embeddings_to_db(phrase, embedding)
                        st.session_state.phrases.append(phrase)
                        added += 1
                if added:
                    # No rerun here: an unconditional rerun while the file is still in
                    # the uploader would retrigger this block on every run
                    st.success(f"Successfully imported {added} new phrases.")
            else:
                st.error("The CSV file must contain a 'phrase' or 'Phrase' column.")
        
        if st.button("Clear All Entries"):
            clear_all_entries()
            st.session_state.phrases = [default_phrase]
            embedding = get_embeddings([default_phrase], tokenizer, model)[0]
            save_embeddings_to_db(default_phrase, embedding)
            st.experimental_rerun()
        
        if st.button("Download Database as CSV"):
            csv = fetch_data_as_csv()
            st.download_button(label="Download CSV", data=csv, file_name='embeddings.csv', mime='text/csv')

    # Main area

    st.subheader(f"Current phrases ({model_choice}):")
    st.write(", ".join(st.session_state.phrases))

    embeddings, phrases = get_all_embeddings()
    if len(embeddings) > 0:
        embeddings = np.array(embeddings)
        plot_interactive_embeddings(embeddings, phrases)
    else:
        st.info("Add phrases using the sidebar to visualize their embeddings.")

if __name__ == "__main__":
    main()
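
# ---------------------------------------------------------------------------
# Note: the database_utils module imported at the top is not included in this
# file. The commented-out code below is a hypothetical sketch of one plausible
# SQLite-backed implementation of the interface this app expects (init_db,
# save_embeddings_to_db, get_all_embeddings, clear_all_entries,
# fetch_data_as_csv); the project's real module may differ.
#
# # database_utils.py (sketch)
# import sqlite3
# import numpy as np
# import pandas as pd
#
# DB_PATH = "embeddings.db"
#
# def init_db():
#     """Create the embeddings table if it does not already exist."""
#     with sqlite3.connect(DB_PATH) as conn:
#         conn.execute(
#             "CREATE TABLE IF NOT EXISTS embeddings (phrase TEXT PRIMARY KEY, vector BLOB)"
#         )
#
# def save_embeddings_to_db(phrase, embedding):
#     """Insert or replace one phrase and its embedding (stored as raw float32 bytes)."""
#     with sqlite3.connect(DB_PATH) as conn:
#         conn.execute(
#             "INSERT OR REPLACE INTO embeddings (phrase, vector) VALUES (?, ?)",
#             (phrase, np.asarray(embedding, dtype=np.float32).tobytes()),
#         )
#
# def get_all_embeddings():
#     """Return (embeddings, phrases) lists for every stored row."""
#     with sqlite3.connect(DB_PATH) as conn:
#         rows = conn.execute("SELECT phrase, vector FROM embeddings").fetchall()
#     phrases = [row[0] for row in rows]
#     embeddings = [np.frombuffer(row[1], dtype=np.float32) for row in rows]
#     return embeddings, phrases
#
# def clear_all_entries():
#     """Delete every stored phrase/embedding pair."""
#     with sqlite3.connect(DB_PATH) as conn:
#         conn.execute("DELETE FROM embeddings")
#
# def fetch_data_as_csv():
#     """Return the whole table as a CSV string for download."""
#     embeddings, phrases = get_all_embeddings()
#     df = pd.DataFrame({"phrase": phrases, "embedding": [e.tolist() for e in embeddings]})
#     return df.to_csv(index=False)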