sahilnishad commited on
Commit
6699d1c
1 Parent(s): d59f397

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from bert_model import BERTEmbedding
3
+ from utils import build_vocab, get_embedding_tensor
4
+ from sklearn.decomposition import PCA
5
+ import plotly.express as px
6
+
7
+
8
+ def plot_2d_embeddings(embeddings, sentence):
9
+ pca = PCA(n_components=2)
10
+ reduced = pca.fit_transform(embeddings[0].detach().numpy())
11
+ fig = px.scatter(
12
+ x=reduced[:, 0],
13
+ y=reduced[:, 1],
14
+ text=sentence.split()
15
+ )
16
+ st.plotly_chart(fig)
17
+
18
+ def plot_3d_embeddings(embeddings, sentence):
19
+ pca = PCA(n_components=3)
20
+ reduced = pca.fit_transform(embeddings[0].detach().numpy())
21
+ fig = px.scatter_3d(
22
+ x=reduced[:, 0],
23
+ y=reduced[:, 1],
24
+ z=reduced[:, 2],
25
+ text=sentence.split()
26
+ )
27
+ st.plotly_chart(fig)
28
+
29
+ # Configuration
30
+ N_SEGMENTS = 2
31
+ MAX_LEN = 512
32
+ EMBED_DIM = 768
33
+ N_LAYERS = 12
34
+ ATTN_HEADS = 12
35
+ DROPOUT = 0.1
36
+
37
+
38
+ def main():
39
+ st.title("BERT Embeddings Visualization")
40
+
41
+ uploaded_file = st.file_uploader("Upload a text file to build vocabulary", type=['txt'])
42
+
43
+ if uploaded_file:
44
+ uploaded_data = uploaded_file.read().decode('utf-8').splitlines()
45
+ st.success("Vocabulary built successfully!")
46
+ else:
47
+ st.warning("Using default vocabulary.")
48
+ with open('data/default_vocab.txt', 'r') as file:
49
+ uploaded_data = file.read().splitlines()
50
+
51
+ vocab = build_vocab(uploaded_data)
52
+
53
+ VOCAB_SIZE = len(vocab)
54
+ embedding_layer = BERTEmbedding(VOCAB_SIZE, N_SEGMENTS, MAX_LEN, EMBED_DIM, DROPOUT)
55
+
56
+ user_sentence = st.text_input("Enter your sentence:", "AI in healthcare predicts patient outcomes and diagnoses.")
57
+
58
+ viz_option = st.selectbox("Select Visualization Type", ["2D", "3D"])
59
+
60
+ if st.button('Visualize Embeddings'):
61
+ embedding_tensor = get_embedding_tensor(user_sentence, vocab, embedding_layer)
62
+
63
+ if viz_option == "2D":
64
+ plot_2d_embeddings(embedding_tensor, user_sentence)
65
+ elif viz_option == "3D":
66
+ plot_3d_embeddings(embedding_tensor, user_sentence)
67
+
68
+ if __name__ == "__main__":
69
+ main()