File size: 2,896 Bytes
4a49186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95c80a2
 
 
4a49186
 
 
 
 
 
 
95c80a2
4a49186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import streamlit as st
from transformers import BertModel, BertTokenizer
import torch
from sklearn.decomposition import PCA
import plotly.graph_objs as go
import numpy as np

def _load_bert(model_name='bert-base-uncased'):
    """Load and return the ``(tokenizer, model)`` pair for ``model_name``.

    Kept separate from ``get_bert_embeddings`` so that the expensive
    loading step can be cached instead of repeated on every call.
    """
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertModel.from_pretrained(model_name)
    # Inference only: put the model in evaluation mode (disables dropout).
    model.eval()
    return tokenizer, model


# Reuse one loaded model instead of reloading the weights on every button
# click. ``st.cache_resource`` is not present in older Streamlit releases,
# so fall back to uncached loading when it is unavailable.
if hasattr(st, 'cache_resource'):
    _load_bert = st.cache_resource(_load_bert)


def get_bert_embeddings(words, model_name='bert-base-uncased'):
    """Embed each word/phrase with BERT and project the vectors to 3-D via PCA.

    Each input is tokenized on its own and represented by the final-layer
    hidden state of its ``[CLS]`` token. The resulting vectors are then
    reduced to three principal components so they can be plotted.

    Args:
        words: Iterable of strings to embed.
        model_name: Model identifier passed to ``from_pretrained``; defaults
            to ``'bert-base-uncased'``.

    Returns:
        A NumPy array of shape ``(len(words), 3)``, or an empty list when
        ``words`` is empty.

    Raises:
        ValueError: If fewer than three words are given, because a
            3-component PCA needs at least three samples.
    """
    words = list(words)
    if not words:
        # Nothing to embed: skip the expensive model load entirely.
        return []
    if len(words) < 3:
        raise ValueError(
            "Need at least 3 words/phrases to compute a 3-component PCA "
            f"projection (got {len(words)})."
        )

    tokenizer, model = _load_bert(model_name)

    embeddings = []
    # Gradients are never used here, so skip autograd bookkeeping.
    with torch.no_grad():
        for word in words:
            inputs = tokenizer(word, return_tensors='pt')
            outputs = model(**inputs)
            # Use the [CLS] token's embedding: position 0 of the first
            # (and only) sequence in the batch.
            embeddings.append(outputs.last_hidden_state[0, 0].cpu().numpy())

    pca = PCA(n_components=3)
    return pca.fit_transform(np.stack(embeddings))


def plot_interactive_bert_embeddings(embeddings, words):
    """Build a 3-D Plotly scatter figure with one labelled marker per word.

    Args:
        embeddings: Indexable of points; ``embeddings[i]`` must provide at
            least three coordinates for ``words[i]``.
        words: Labels for the points, in the same order as ``embeddings``.

    Returns:
        A ``go.Figure``, or ``None`` after showing a Streamlit error when
        fewer than four words are supplied.
    """
    if len(words) < 4:
        st.error("Please provide at least 4 words/phrases for effective visualization.")
        return None

    # One trace per word so each point gets its own text label and legend entry.
    traces = [
        go.Scatter3d(
            x=[embeddings[idx][0]],
            y=[embeddings[idx][1]],
            z=[embeddings[idx][2]],
            mode='markers+text',
            text=[label],
            name=label,
        )
        for idx, label in enumerate(words)
    ]

    scene_axes = dict(
        xaxis=dict(title='PCA Component 1'),
        yaxis=dict(title='PCA Component 2'),
        zaxis=dict(title='PCA Component 3'),
    )
    figure_layout = go.Layout(
        title='3D Scatter Plot of BERT Embeddings',
        scene=scene_axes,
        autosize=False,
        width=800,
        height=600,
    )

    return go.Figure(data=traces, layout=figure_layout)

def main():
    """Render the Streamlit UI for building a word list and plotting its embeddings.

    The word list is kept in ``st.session_state.words`` so it persists
    between interactions. Buttons let the user add a word/phrase, generate
    and plot the BERT embeddings, or clear the list.
    """
    st.title("BERT Embeddings Visualization")

    # Initialize or get existing words list from the session state
    if 'words' not in st.session_state:
        st.session_state.words = []

    def _reset_words():
        # Runs as a button callback, i.e. before the page is rendered again,
        # so the cleared list is reflected immediately.
        st.session_state.words = []

    # Text input for new words
    new_words_input = st.text_input("Enter a new word/phrase:")

    # Button to add new words
    if st.button("Add Word/Phrase"):
        # Ignore empty and whitespace-only entries.
        word = new_words_input.strip()
        if word:
            st.session_state.words.append(word)
            st.success(f"Added: {word}")

    # Display current list of words
    if st.session_state.words:
        st.write("Current list of words/phrases:", ', '.join(st.session_state.words))

    # Generate embeddings and plot
    if st.button("Generate Embeddings"):
        words = st.session_state.words
        # Check the count up front: too few words would otherwise load and run
        # BERT for nothing, or fail inside the 3-component PCA before the
        # plotting step could show this message.
        if len(words) < 4:
            st.error("Please provide at least 4 words/phrases for effective visualization.")
        else:
            with st.spinner('Generating embeddings...'):
                embeddings = get_bert_embeddings(words)
                fig = plot_interactive_bert_embeddings(embeddings, words)
            if fig is not None:
                st.plotly_chart(fig, use_container_width=True)

    # Reset button (handled via callback so the displayed list is not stale)
    st.button("Reset", on_click=_reset_words)

# Entry point: build the UI when this file is executed as the main script
# (not when it is imported as a module).
if __name__ == "__main__":
    main()