# Residue from a web file viewer (not program code), kept for provenance:
# paulokewunmi's picture | Upload 7 files | 33dca10 | raw |
# history blame contribute delete | No virus | 1.51 kB
import streamlit as st
from joblib import load
import pandas as pd
import plotly.express as px
# Precomputed artifacts consumed by main(); loaded each time this module is
# executed, from paths relative to the current working directory.
# NOTE: joblib.load unpickles its input, so these files must come from a
# trusted source.
reduced_embeddings, combined_sentences, languages, clusters = (
    load(artifact_path)
    for artifact_path in (
        "embeddings_tsne.joblib",
        "sentence_list.joblib",
        "language_list.joblib",
        "clusters.joblib",
    )
)
def main():
    """Render the Streamlit page: a cluster filter and a t-SNE scatter plot.

    Reads the module-level ``reduced_embeddings``, ``combined_sentences``,
    ``languages`` and ``clusters`` objects, assembles them into a single
    DataFrame, lets the user pick which clusters to show, and draws the
    selected points coloured by language. If no cluster is selected, a
    warning is shown instead of an empty chart.
    """
    st.title("LASER Multilingual Sentence Embeddings Visualization")

    # Columns 0 and 1 of the reduced embeddings become the plot's x and y axes.
    df = pd.DataFrame({
        'TSNE Component 1': reduced_embeddings[:, 0],
        'TSNE Component 2': reduced_embeddings[:, 1],
        'Language': languages,
        'Sentence': combined_sentences,
        'Cluster': [f'Cluster {cluster}' for cluster in clusters],
    })

    select_all = st.checkbox("Select All Clusters")

    # A plain list (rather than the ndarray returned by unique()) keeps the
    # emptiness check below unambiguous and gives the widget simple options.
    unique_clusters = df['Cluster'].unique().tolist()

    # The multiselect is shown only while 'Select All Clusters' is unchecked.
    if not select_all:
        selected_clusters = st.multiselect(
            "Select clusters to display",
            unique_clusters,
            default=unique_clusters[:10],
        )
    else:
        selected_clusters = unique_clusters

    # Nothing to plot: tell the user why instead of rendering a blank chart.
    if not selected_clusters:
        st.warning("No clusters selected. Select at least one cluster to display the plot.")
        return

    filtered_df = df[df['Cluster'].isin(selected_clusters)]

    fig = px.scatter(
        filtered_df,
        x='TSNE Component 1',
        y='TSNE Component 2',
        color='Language',
        hover_data=['Sentence', 'Cluster'],
    )
    fig.update_layout(
        title="Multilingual Sentence Embeddings Visualization",
        xaxis_title="TSNE Component 1",
        yaxis_title="TSNE Component 2",
        legend_title="Language",
    )
    st.plotly_chart(fig, use_container_width=True)
# Build the page only when this file is run as the entry script, so that
# importing the module does not render anything.
if __name__ == "__main__":
    main()