import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

from src.subpages.page import Context, Page


class HiddenStatesVisualizer:
    """Render a 2-D projection of per-token hidden states in Streamlit.

    The visualizer works on a private copy of ``context.df_tokens_merged``
    and reads its ``tokens``, ``ids``, ``hidden_states``, ``labels`` and
    ``preds`` columns. Rendering adds helper columns (``x``, ``y``,
    ``sent0``..``sent4``, ``disagreements``) to that copy only.
    """

    # Each hover snippet is a fixed-width character window of the sentence.
    _SNIPPET_WIDTH = 50
    _N_SNIPPETS = 5

    def __init__(self, context: Context):
        self.context = context
        # Copy so the helper columns added while rendering stay local.
        self.df = context.df_tokens_merged.copy()

    def _reduce_dim_svd(self, X, n_iter: int, random_state=42):
        """Project ``X`` with SVD (placeholder).

        Currently returns ``None``; ``visualize_hidden_states`` requires a
        ``numpy.ndarray`` and reads its first two columns as x and y.
        """
        # Implement your SVD reduction here
        pass

    def _reduce_dim_pca(self, X, random_state=42):
        """Project ``X`` with PCA (placeholder).

        Currently returns ``None``; ``visualize_hidden_states`` requires a
        ``numpy.ndarray`` and reads its first two columns as x and y.
        """
        # Implement your PCA reduction here
        pass

    def _reduce_dim_umap(self, X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
        """Project ``X`` with UMAP (placeholder).

        Currently returns ``None``; ``visualize_hidden_states`` requires a
        ``numpy.ndarray`` and reads its first two columns as x and y.
        """
        # Implement your UMAP reduction here
        pass

    def visualize_hidden_states(self):
        """Draw the settings column and the two projection scatter plots.

        The left column collects the number of tokens to show and the
        dimensionality-reduction algorithm with its parameters. The right
        column reduces all hidden states, then plots the first ``n_tokens``
        rows twice: once colored by label, once by prediction. Rows where
        label and prediction differ get an extra outlined marker.

        Raises:
            TypeError: If the selected reducer does not return a
                ``numpy.ndarray`` (e.g. a ``_reduce_dim_*`` placeholder has
                not been implemented yet).
        """
        st.title("Embeddings")
        with st.expander("💡", expanded=True):
            st.write(
                "For every token in the dataset, we take its hidden state and project it "
                "onto a two-dimensional plane. Data points are colored by label/prediction, "
                "with disagreements signified by a small black border."
            )

        col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])

        # Fallbacks; both are overwritten by the widgets below.
        dim_algo = "SVD"
        n_tokens = 100
        # Keyword arguments for the selected reducer, filled in next to the
        # widgets that produce them so nothing is conditionally unbound later.
        reduce_kwargs = {}

        with col1:
            st.subheader("Settings")
            # NOTE(review): the upper bound counts distinct token strings,
            # while the plotted subset below is the first n_tokens *rows* --
            # confirm this is the intended limit.
            n_tokens = st.slider(
                "#tokens",
                key="n_tokens",
                min_value=100,
                max_value=len(self.df["tokens"].unique()),
                step=100,
            )

            dim_algo = st.selectbox("Dimensionality reduction algorithm", ["SVD", "PCA", "UMAP"])
            if dim_algo == "SVD":
                reduce_kwargs["n_iter"] = st.slider(
                    "#iterations",
                    key="svd_n_iter",
                    min_value=1,
                    max_value=10,
                    step=1,
                )
            elif dim_algo == "UMAP":
                reduce_kwargs["n_neighbors"] = st.slider(
                    "#neighbors",
                    key="umap_n_neighbors",
                    min_value=2,
                    max_value=100,
                    step=1,
                )
                reduce_kwargs["min_dist"] = st.number_input(
                    "Min distance", key="umap_min_dist", value=0.1, min_value=0.0, max_value=1.0
                )
                reduce_kwargs["metric"] = st.selectbox(
                    "Metric", ["euclidean", "manhattan", "chebyshev", "minkowski"]
                )

        with col2:
            # One space-joined string per sentence id.
            sents = self.df.groupby("ids")["tokens"].agg(" ".join)

            X = np.array(self.df["hidden_states"].tolist())

            reducers = {
                "SVD": self._reduce_dim_svd,
                "PCA": self._reduce_dim_pca,
                "UMAP": self._reduce_dim_umap,
            }
            reducer = reducers.get(dim_algo)
            transformed_hidden_states = None
            if reducer is not None:
                transformed_hidden_states = reducer(X, **reduce_kwargs)

            # Explicit check instead of `assert`, which is removed under -O.
            if not isinstance(transformed_hidden_states, np.ndarray):
                raise TypeError(
                    f"{dim_algo} reduction returned "
                    f"{type(transformed_hidden_states).__name__}, expected numpy.ndarray; "
                    "is the corresponding _reduce_dim_* method implemented?"
                )

            self.df["x"] = transformed_hidden_states[:, 0]
            self.df["y"] = transformed_hidden_states[:, 1]

            # Hover text: consecutive character windows of the sentence, with
            # whitespace normalized. Computed once per sentence, then mapped
            # onto the token rows through their sentence id.
            width = self._SNIPPET_WIDTH
            for i in range(self._N_SNIPPETS):
                start = i * width
                snippets = sents.map(
                    lambda sent, a=start, b=start + width: " ".join(sent[a:b].split())
                )
                self.df[f"sent{i}"] = self.df["ids"].map(snippets)

            self.df["disagreements"] = self.df["labels"] != self.df["preds"]

            subset = self.df[:n_tokens]
            mismatched = subset[subset["disagreements"]]
            # Transparent fill with an outline only, drawn on top of the
            # regular markers for rows where label and prediction differ.
            disagreements_trace = go.Scatter(
                x=mismatched["x"],
                y=mismatched["y"],
                mode="markers",
                marker=dict(
                    size=6,
                    color="rgba(0,0,0,0)",
                    line=dict(width=1),
                ),
                hoverinfo="skip",
            )

            st.subheader("Projection Results")

            fig = px.scatter(
                subset,
                x="x",
                y="y",
                color="labels",
                hover_data=["ids", "preds", "sent0", "sent1", "sent2", "sent3", "sent4"],
                hover_name="tokens",
                title="Colored by label",
            )
            fig.add_trace(disagreements_trace)
            st.plotly_chart(fig)

            fig = px.scatter(
                subset,
                x="x",
                y="y",
                color="preds",
                hover_data=["ids", "labels", "sent0", "sent1", "sent2", "sent3", "sent4"],
                hover_name="tokens",
                title="Colored by prediction",
            )
            fig.add_trace(disagreements_trace)
            st.plotly_chart(fig)


class HiddenStatesPage(Page):
    """Page that shows the hidden-state projection for the given context."""

    name = "Hidden States"
    icon = "grid-3x3"

    def render(self, context: Context):
        """Build a ``HiddenStatesVisualizer`` for ``context`` and render it."""
        visualizer = HiddenStatesVisualizer(context)
        visualizer.visualize_hidden_states()