Spaces:
Runtime error
Runtime error
File size: 5,479 Bytes
597bf7d fb9cb6e 597bf7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from subpages.page import Context, Page
@st.cache
def reduce_dim_svd(X, n_iter, random_state=42):
    """Project ``X`` onto two components using truncated SVD.

    Args:
        X: Input matrix passed to ``TruncatedSVD.fit_transform``.
        n_iter: Iteration count forwarded to ``TruncatedSVD``.
        random_state: Seed forwarded to ``TruncatedSVD``.

    Returns:
        Whatever ``TruncatedSVD(n_components=2).fit_transform(X)`` returns.
    """
    # Function-scope import: scikit-learn is only loaded once this reducer is called.
    from sklearn.decomposition import TruncatedSVD

    reducer = TruncatedSVD(
        n_components=2,
        n_iter=n_iter,
        random_state=random_state,
    )
    projection = reducer.fit_transform(X)
    return projection
@st.cache
def reduce_dim_pca(X, random_state=42):
    """Project ``X`` onto its first two principal components.

    Args:
        X: Input matrix passed to ``PCA.fit_transform``.
        random_state: Seed forwarded to ``PCA``.

    Returns:
        Whatever ``PCA(n_components=2).fit_transform(X)`` returns.
    """
    # Function-scope import: scikit-learn is only loaded once this reducer is called.
    from sklearn.decomposition import PCA

    reducer = PCA(n_components=2, random_state=random_state)
    projection = reducer.fit_transform(X)
    return projection
@st.cache
def reduce_dim_umap(X, n_neighbors=5, min_dist=0.1, metric="euclidean", random_state=None):
    """Project ``X`` with UMAP and return the embedding.

    Args:
        X: Input matrix passed to ``UMAP.fit_transform``.
        n_neighbors: Forwarded to ``UMAP``.
        min_dist: Forwarded to ``UMAP``.
        metric: Forwarded to ``UMAP``.
        random_state: Seed forwarded to ``UMAP``. Defaults to ``None``, which
            is also UMAP's own default, so callers that omit it get the same
            behavior as before this parameter existed. Exposed so this
            reducer mirrors ``reduce_dim_svd`` / ``reduce_dim_pca``.

    Returns:
        Whatever ``UMAP(...).fit_transform(X)`` returns.
    """
    # Function-scope import: umap-learn is only loaded once this reducer is called.
    from umap import UMAP

    reducer = UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        metric=metric,
        random_state=random_state,
    )
    return reducer.fit_transform(X)
class HiddenStatesPage(Page):
    """Streamlit page that plots 2-D projections of token hidden states.

    The page shows a settings column (token count, reduction algorithm and its
    parameters) next to two scatter plots of the projected hidden states, one
    colored by ``labels`` and one by ``preds``.
    """

    name = "Hidden States"
    icon = "grid-3x3"

    def get_widget_defaults(self):
        """Return default values keyed by the ``key`` of this page's widgets.

        How these defaults are applied is decided by the ``Page`` base class,
        which is not visible here (presumably they seed ``st.session_state``).
        """
        return {
            "n_tokens": 1_000,
            "svd_n_iter": 5,
            "svd_random_state": 42,
            "umap_n_neighbors": 15,
            "umap_metric": "euclidean",
            "umap_min_dist": 0.1,
        }

    def _render_settings(self, df):
        """Draw the settings widgets and return the user's choices.

        Args:
            df: Token dataframe; only its ``tokens`` column is read here.

        Returns:
            A ``(n_tokens, dim_algo, algo_params)`` tuple where ``algo_params``
            holds the keyword arguments for the selected reducer (empty for
            PCA, which has no configurable widget).
        """
        st.subheader("Settings")
        # NOTE(review): the upper bound counts *unique* tokens, while the value
        # is later used as a row count (``df[:n_tokens]``) -- confirm intent.
        n_tokens = st.slider(
            "#tokens",
            key="n_tokens",
            min_value=100,
            max_value=len(df["tokens"].unique()),
            step=100,
        )

        dim_algo = st.selectbox("Dimensionality reduction algorithm", ["SVD", "PCA", "UMAP"])

        # Collect per-algorithm parameters in one dict so no variable is left
        # conditionally unbound.
        algo_params = {}
        if dim_algo == "SVD":
            algo_params["n_iter"] = st.slider(
                "#iterations",
                key="svd_n_iter",
                min_value=1,
                max_value=10,
                step=1,
            )
        elif dim_algo == "UMAP":
            algo_params["n_neighbors"] = st.slider(
                "#neighbors",
                key="umap_n_neighbors",
                min_value=2,
                max_value=100,
                step=1,
            )
            algo_params["min_dist"] = st.number_input(
                "Min distance", key="umap_min_dist", value=0.1, min_value=0.0, max_value=1.0
            )
            # Keyed like the other widgets so the "umap_metric" entry from
            # get_widget_defaults is bound to this selectbox.
            algo_params["metric"] = st.selectbox(
                "Metric",
                ["euclidean", "manhattan", "chebyshev", "minkowski"],
                key="umap_metric",
            )
        return n_tokens, dim_algo, algo_params

    def render(self, context: Context):
        """Render the page: explanation, settings column and both scatter plots.

        Args:
            context: Page context; ``context.df_tokens_merged`` is copied and
                must provide the ``tokens``, ``ids``, ``hidden_states``,
                ``labels`` and ``preds`` columns read below.
        """
        st.title("Embeddings")

        with st.expander("💡", expanded=True):
            st.write(
                "For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with mislabeled examples signified by a small black border."
            )

        col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])
        # Work on a copy so the helper columns added below never leak into the context.
        df = context.df_tokens_merged.copy()

        with col1:
            n_tokens, dim_algo, algo_params = self._render_settings(df)

        with col2:
            # One space-joined sentence string per sentence id.
            sents = df.groupby("ids")["tokens"].apply(" ".join)

            X = np.array(df["hidden_states"].tolist())
            transformed_hidden_states = None
            if dim_algo == "SVD":
                transformed_hidden_states = reduce_dim_svd(X, **algo_params)
            elif dim_algo == "PCA":
                transformed_hidden_states = reduce_dim_pca(X)
            elif dim_algo == "UMAP":
                transformed_hidden_states = reduce_dim_umap(X, **algo_params)

            assert isinstance(transformed_hidden_states, np.ndarray)

            df["x"] = transformed_hidden_states[:, 0]
            df["y"] = transformed_hidden_states[:, 1]

            # Split each sentence into five 50-character chunks (sent0..sent4)
            # for the hover box. Each chunk is computed once per sentence and
            # then looked up per token row by sentence id.
            hover_columns = []
            for chunk_idx in range(5):
                start = 50 * chunk_idx
                stop = start + 50
                chunk_by_id = sents.map(lambda s: " ".join(s[start:stop].split()))
                column = f"sent{chunk_idx}"
                df[column] = df["ids"].map(chunk_by_id)
                hover_columns.append(column)

            df["mislabeled"] = df["labels"] != df["preds"]

            subset = df[:n_tokens]
            mislabeled_rows = subset[subset["mislabeled"]]
            # Transparent markers with a 1px outline, drawn on top of the
            # mislabeled points; hover is skipped on this overlay trace.
            mislabeled_examples_trace = go.Scatter(
                x=mislabeled_rows["x"],
                y=mislabeled_rows["y"],
                mode="markers",
                marker=dict(
                    size=6,
                    color="rgba(0,0,0,0)",
                    line=dict(width=1),
                ),
                hoverinfo="skip",
            )

            st.subheader("Projection Results")

            # Same plot twice; only the coloring column and title differ.
            for color_column, title in (
                ("labels", "Colored by label"),
                ("preds", "Colored by prediction"),
            ):
                fig = px.scatter(
                    subset,
                    x="x",
                    y="y",
                    color=color_column,
                    hover_data=hover_columns,
                    hover_name="tokens",
                    title=title,
                )
                fig.add_trace(mislabeled_examples_trace)
                st.plotly_chart(fig)
|