sahilnishad's picture
Update app.py
7191263
raw
history blame contribute delete
No virus
2.07 kB
import streamlit as st
from bert_model import BERTEmbedding
from utils import build_vocab, get_embedding_tensor
from sklearn.decomposition import PCA
import plotly.express as px
def plot_2d_embeddings(embeddings, sentence):
pca = PCA(n_components=2)
reduced = pca.fit_transform(embeddings[0].detach().numpy())
fig = px.scatter(
x=reduced[:, 0],
y=reduced[:, 1],
text=sentence.split()
)
st.plotly_chart(fig)
def plot_3d_embeddings(embeddings, sentence):
pca = PCA(n_components=3)
reduced = pca.fit_transform(embeddings[0].detach().numpy())
fig = px.scatter_3d(
x=reduced[:, 0],
y=reduced[:, 1],
z=reduced[:, 2],
text=sentence.split()
)
st.plotly_chart(fig)
# Model configuration. main() passes N_SEGMENTS, MAX_LEN, EMBED_DIM and
# DROPOUT (together with the vocabulary size) positionally to BERTEmbedding.
N_SEGMENTS: int = 2
MAX_LEN: int = 512
EMBED_DIM: int = 768
# NOTE(review): N_LAYERS and ATTN_HEADS are not passed to BERTEmbedding in
# main(); presumably kept for a fuller BERT model. Confirm before removing.
N_LAYERS: int = 12
ATTN_HEADS: int = 12
DROPOUT: float = 0.1
def main():
st.title("BERT Embeddings Visualization")
uploaded_file = st.file_uploader("Upload a text file to build vocabulary", type=['txt'])
if uploaded_file:
uploaded_data = uploaded_file.read().decode('utf-8').splitlines()
st.success("Vocabulary built successfully!")
else:
st.warning("Using default vocabulary.")
with open('default_vocab.txt', 'r') as file:
uploaded_data = file.read().splitlines()
vocab = build_vocab(uploaded_data)
VOCAB_SIZE = len(vocab)
embedding_layer = BERTEmbedding(VOCAB_SIZE, N_SEGMENTS, MAX_LEN, EMBED_DIM, DROPOUT)
user_sentence = st.text_input("Enter your sentence:", "AI in healthcare predicts patient outcomes and diagnoses.")
viz_option = st.selectbox("Select Visualization Type", ["2D", "3D"])
if st.button('Visualize Embeddings'):
embedding_tensor = get_embedding_tensor(user_sentence, vocab, embedding_layer)
if viz_option == "2D":
plot_2d_embeddings(embedding_tensor, user_sentence)
elif viz_option == "3D":
plot_3d_embeddings(embedding_tensor, user_sentence)
if __name__ == "__main__":
main()