"""Streamlit app that visualises SimpleSegmentT5 word embeddings in 3-D.

The embedding matrix and its word list are loaded from ``data/``, optionally
filtered by capitalisation and word length, projected to three dimensions
with PCA, and rendered as an interactive Plotly scatter plot.
"""

import json

import numpy as np
import plotly.express as px
import streamlit as st
from Levenshtein import distance
from sklearn.decomposition import PCA

EMBEDDINGS_PATH = "data/simplesegmentT5_embeddings.npy"
WORDS_PATH = "data/words.json"
# Target dimensionality of the PCA projection; PCA also needs at least this
# many rows to fit.
N_COMPONENTS = 3
# Position (in the filtered word list) of the word used as the Levenshtein
# reference when no explicit reference is given.
LEVENSHTEIN_REFERENCE_INDEX = 4


def load_data():
    """Load the embedding matrix and its word list from ``data/``.

    Returns:
        An ``(embeddings, words)`` tuple: the array stored in the ``.npy``
        file and the object decoded from ``words.json``. The rest of the app
        indexes both with the same positions, i.e. it assumes row ``i`` of
        ``embeddings`` belongs to ``words[i]``.
    """
    embeddings = np.load(EMBEDDINGS_PATH)
    # Use a context manager so the handle is closed. JSON text is UTF-8, so
    # don't depend on the platform's default locale encoding.
    with open(WORDS_PATH, "r", encoding="utf-8") as f:
        words = json.load(f)
    return embeddings, words


def project_embeddings(embeddings, n_components=N_COMPONENTS):
    """Project ``embeddings`` onto their first ``n_components`` principal axes.

    A new PCA is fitted on every call, so the projection depends on exactly
    which rows are passed in.

    Args:
        embeddings: 2-D array, one embedding vector per row.
        n_components: Output dimensionality (default 3).

    Returns:
        Array of shape ``(len(embeddings), n_components)``.

    Raises:
        ValueError: from scikit-learn when there are fewer rows or columns
            than ``n_components``.
    """
    return PCA(n_components=n_components).fit_transform(embeddings)


def filter_words(words, remove_capitalized, length):
    """Return the positions of the words that pass the sidebar filters.

    Args:
        words: Sequence of strings.
        remove_capitalized: If true, drop every word that changes under
            ``str.lower`` (i.e. contains an upper-case character).
        length: ``(min_len, max_len)`` pair; both bounds are inclusive.

    Returns:
        List of indices into ``words``, in their original order.
    """
    min_len, max_len = length
    return [
        i
        for i, w in enumerate(words)
        if min_len <= len(w) <= max_len
        and not (remove_capitalized and w.lower() != w)
    ]


def color_length(words):
    """Return one colour value per word: its length in characters."""
    return [len(w) for w in words]


def color_first_letter(words):
    """Return one colour value per word, in [0, 1], from its first letter.

    'a' maps to 0 and 'z' to 25/26. Characters outside a-z are clamped into
    [0, 1], and empty strings map to 0.
    """
    # Lower-case the whole word before taking [0]: some characters lower-case
    # to several code points, which ord() would reject.
    return [
        min(1, max(0, (ord(w.lower()[0]) - ord("a")) / 26)) if w else 0
        for w in words
    ]


def color_levenshtein(words, reference=None):
    """Return one colour value per word: its edit distance to a reference.

    Args:
        words: Sequence of strings.
        reference: Word to measure against. When omitted, the word at
            ``LEVENSHTEIN_REFERENCE_INDEX`` is used (the original behaviour).
            If ``words`` is shorter than that, the last word is used instead.

    Returns:
        List of integer Levenshtein distances, one per word (empty for empty
        input).
    """
    if reference is None:
        if not words:
            return []
        reference = words[min(LEVENSHTEIN_REFERENCE_INDEX, len(words) - 1)]
    return [distance(w, reference) for w in words]


def plot_scatter(words, embeddings, remove_capitalized, length, color_select):
    """Build the 3-D PCA scatter figure for the words that pass the filters.

    Args:
        words: Word list aligned with the rows of ``embeddings``.
        embeddings: Array of embedding vectors, indexed by word position.
        remove_capitalized: Forwarded to :func:`filter_words`.
        length: ``(min_len, max_len)`` forwarded to :func:`filter_words`.
        color_select: ``"Word length"`` colours points by length. Any other
            value colours them by Levenshtein distance
            (see :func:`color_levenshtein`).

    Returns:
        A Plotly 3-D scatter figure.

    Raises:
        ValueError: if fewer than ``N_COMPONENTS`` words survive the filters,
            because PCA cannot produce the projection then.
    """
    idx = filter_words(words, remove_capitalized, length)
    if len(idx) < N_COMPONENTS:
        raise ValueError(
            f"Only {len(idx)} word(s) match the current filters; at least "
            f"{N_COMPONENTS} are needed for the PCA projection."
        )

    filtered_embeddings = embeddings[idx]
    filtered_words = [words[i] for i in idx]
    proj = project_embeddings(filtered_embeddings)

    if color_select == "Word length":
        color = color_length(filtered_words)
    else:
        color = color_levenshtein(filtered_words)

    fig = px.scatter_3d(
        x=proj[:, 0],
        y=proj[:, 1],
        z=proj[:, 2],
        width=800,
        height=600,
        color=color,
        color_continuous_scale=px.colors.sequential.Viridis,
        hover_name=filtered_words,
        title="SimpleSegmentT5 Embeddings",
    )
    fig.update_traces(
        marker={"size": 6, "line": {"width": 2}},
        selector={"mode": "markers"},
    )
    return fig


def main():
    """Render the page: sidebar filter/colour controls plus the scatter plot."""
    # NOTE(review): data is reloaded on every script run; consider caching
    # load_data if the installed Streamlit version provides a cache decorator.
    embeddings, words = load_data()

    st.sidebar.title("Settings")
    remove_checkbox = st.sidebar.checkbox(
        "Remove capitalized words",
        value=True,
        # NOTE(review): the key says "include" although the box removes
        # words; left unchanged so the widget's identity stays the same.
        key="include_capitalized",
    )
    length_slider = st.sidebar.slider("Word length", 3, 9, (3, 9))
    # NOTE(review): the reference word is not random. It is the word at
    # LEVENSHTEIN_REFERENCE_INDEX in the filtered list.
    color_select = st.sidebar.radio(
        "Color by", ["Word length", "Levenshtein distance to random word"]
    )

    try:
        fig = plot_scatter(
            words, embeddings, remove_checkbox, length_slider, color_select
        )
    except ValueError as err:
        # plot_scatter raises ValueError when the filters leave too few words
        # to project; show that to the user instead of a traceback.
        st.warning(str(err))
        return
    st.plotly_chart(fig)


if __name__ == "__main__":
    main()