|
import streamlit as st |
|
|
|
from joblib import load |
|
import pandas as pd |
|
import plotly.express as px |
|
|
|
|
|
@st.cache_data
def _load_artifacts():
    """Load the precomputed visualization inputs from disk, cached across reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    uncached module-level ``load`` calls would re-read every file on each
    click. ``st.cache_data`` (Streamlit >= 1.18) performs the disk reads once
    and serves the cached result on later reruns.

    NOTE(review): ``joblib.load`` unpickles its input; only ship/load
    artifact files produced by a trusted pipeline.

    Returns:
        Tuple ``(reduced_embeddings, combined_sentences, languages, clusters)``
        loaded from the four files below, in that order.
    """
    return (
        load("embeddings_tsne.joblib"),
        load("sentence_list.joblib"),
        load("language_list.joblib"),
        load("clusters.joblib"),
    )


# Module-level names are kept so existing readers (``main``) are unaffected.
reduced_embeddings, combined_sentences, languages, clusters = _load_artifacts()
|
|
|
_CLUSTER_PREFIX = 'Cluster '


def _build_frame(embeddings_2d, sentences, langs, cluster_ids):
    """Assemble the plotting table from the loaded artifacts.

    Args:
        embeddings_2d: 2-D array whose columns 0 and 1 are the t-SNE
            components (presumably shape ``(n_sentences, 2)`` — confirm).
        sentences: Sentence text, one entry per embedding row.
        langs: Language label, one entry per embedding row.
        cluster_ids: Cluster id, one entry per embedding row.

    Returns:
        DataFrame with columns ``TSNE Component 1``, ``TSNE Component 2``,
        ``Language``, ``Sentence`` and ``Cluster``; cluster ids are rendered
        as ``"Cluster <id>"`` labels.

    Raises:
        ValueError: from pandas if the inputs differ in length.
    """
    return pd.DataFrame({
        'TSNE Component 1': embeddings_2d[:, 0],
        'TSNE Component 2': embeddings_2d[:, 1],
        'Language': langs,
        'Sentence': sentences,
        'Cluster': [f'{_CLUSTER_PREFIX}{cluster}' for cluster in cluster_ids],
    })


def _cluster_sort_key(label):
    """Sort key that orders ``"Cluster <id>"`` labels by numeric id when possible.

    Integer ids (including negative ones) sort numerically, so
    ``Cluster 2`` precedes ``Cluster 10``. Any label whose id is not an
    integer sorts after all integer ids, alphabetically.
    """
    suffix = label[len(_CLUSTER_PREFIX):] if label.startswith(_CLUSTER_PREFIX) else label
    try:
        return (0, int(suffix), suffix)
    except ValueError:
        # Non-integer id: keep a deterministic, comparable key.
        return (1, 0, suffix)


def _select_clusters(cluster_labels):
    """Render the cluster-selection widgets and return the chosen labels.

    Args:
        cluster_labels: Ordered list of all available cluster labels.

    Returns:
        Every label when "Select All Clusters" is ticked; otherwise the labels
        picked in the multiselect, which defaults to the first ten.
    """
    if st.checkbox("Select All Clusters"):
        return list(cluster_labels)
    return st.multiselect("Select clusters to display", cluster_labels,
                          default=cluster_labels[:10])


def _make_figure(frame):
    """Build the t-SNE scatter plot, coloured by language, with sentence/cluster hover text."""
    fig = px.scatter(frame, x='TSNE Component 1', y='TSNE Component 2',
                     color='Language', hover_data=['Sentence', 'Cluster'])
    fig.update_layout(title="Multilingual Sentence Embeddings Visualization",
                      xaxis_title="TSNE Component 1", yaxis_title="TSNE Component 2",
                      legend_title="Language")
    return fig


def main():
    """Render the LASER embeddings explorer page.

    Shows the page title, a cluster selector, and a scatter plot of the
    module-level ``reduced_embeddings`` restricted to the selected clusters.
    Reads the module-level ``reduced_embeddings``, ``combined_sentences``,
    ``languages`` and ``clusters``.
    """
    st.title("LASER Multilingual Sentence Embeddings Visualization")

    df = _build_frame(reduced_embeddings, combined_sentences, languages, clusters)

    # unique() returns labels in order of first appearance in the data; sort
    # them so the picker is scannable and the default "first ten" is the ten
    # lowest cluster ids rather than an arbitrary data-order-dependent set.
    unique_clusters = sorted(df['Cluster'].unique(), key=_cluster_sort_key)

    selected_clusters = _select_clusters(unique_clusters)
    if not selected_clusters:
        # An empty selection would otherwise render a blank chart with no hint.
        st.info("No clusters selected. Choose at least one cluster to display.")
        return

    filtered_df = df[df['Cluster'].isin(selected_clusters)]
    st.plotly_chart(_make_figure(filtered_df), use_container_width=True)


if __name__ == "__main__":
    main()