File size: 5,728 Bytes
c8d36ae
 
 
597bf7d
 
 
 
 
7a75a86
597bf7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb9cb6e
 
597bf7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2918df9
597bf7d
 
 
 
 
 
 
 
 
 
 
2918df9
597bf7d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""
For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with mislabeled examples marked by a small black border.
"""
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

from src.subpages.page import Context, Page


@st.cache
def reduce_dim_svd(X, n_iter, random_state=42):
    """Project ``X`` onto two components using truncated SVD.

    Wrapped in ``st.cache`` so Streamlit can reuse the result across reruns
    when called with the same arguments.

    Args:
        X: Input matrix handed to ``TruncatedSVD.fit_transform``.
        n_iter: Forwarded to ``TruncatedSVD`` as ``n_iter``.
        random_state: Seed forwarded to ``TruncatedSVD``.

    Returns:
        The two-component output of ``fit_transform``.
    """
    # Imported lazily so scikit-learn is only loaded when this reducer runs.
    from sklearn.decomposition import TruncatedSVD

    return TruncatedSVD(
        n_components=2,
        n_iter=n_iter,
        random_state=random_state,
    ).fit_transform(X)


@st.cache
def reduce_dim_pca(X, random_state=42):
    """Project ``X`` onto its first two principal components.

    Wrapped in ``st.cache`` so Streamlit can reuse the result across reruns
    when called with the same arguments.

    Args:
        X: Input matrix handed to ``PCA.fit_transform``.
        random_state: Seed forwarded to ``PCA``.

    Returns:
        The two-component output of ``fit_transform``.
    """
    # Imported lazily so scikit-learn is only loaded when this reducer runs.
    from sklearn.decomposition import PCA

    reducer = PCA(n_components=2, random_state=random_state)
    return reducer.fit_transform(X)


@st.cache
def reduce_dim_umap(X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
    """Embed ``X`` with UMAP using the estimator's default output size.

    Wrapped in ``st.cache`` so Streamlit can reuse the result across reruns
    when called with the same arguments.

    Args:
        X: Input matrix handed to ``UMAP.fit_transform``.
        n_neighbors: Forwarded to ``UMAP`` as ``n_neighbors``.
        min_dist: Forwarded to ``UMAP`` as ``min_dist``.
        metric: Forwarded to ``UMAP`` as ``metric``.

    Returns:
        The output of ``fit_transform``.
    """
    # Imported lazily so umap-learn is only loaded when this reducer runs.
    from umap import UMAP

    reducer = UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        metric=metric,
    )
    return reducer.fit_transform(X)


class HiddenStatesPage(Page):
    """Page that plots a two-dimensional projection of token hidden states.

    The left column holds the settings (number of rows to plot and the
    dimensionality-reduction algorithm with its parameters); the right column
    shows two scatter plots of the projected points, one colored by label and
    one colored by prediction. Rows whose label differs from the prediction
    get an extra outlined marker on top.
    """

    name = "Hidden States"
    icon = "grid-3x3"

    # Each sentence is cut into fixed-size character chunks that become the
    # hover columns ``sent0`` .. ``sent4`` (presumably to keep the tooltip
    # narrow). Characters beyond count * size are not shown.
    _HOVER_CHUNK_CHARS = 50
    _HOVER_CHUNK_COUNT = 5

    def get_widget_defaults(self):
        """Return default values for this page's settings, keyed by name.

        Except for ``svd_random_state``, every key matches the ``key=`` of a
        widget created in :meth:`render`.
        """
        return {
            "n_tokens": 1_000,
            "svd_n_iter": 5,
            "svd_random_state": 42,
            "umap_n_neighbors": 15,
            "umap_metric": "euclidean",
            "umap_min_dist": 0.1,
        }

    def render(self, context: Context):
        """Draw the settings column and the two projection scatter plots.

        Args:
            context: Must expose ``df_tokens_merged``, a DataFrame with the
                columns ``ids``, ``tokens``, ``hidden_states``, ``labels`` and
                ``preds``. The frame is copied before columns are added.
        """
        st.title("Embeddings")

        with st.expander("💡", expanded=True):
            st.write(
                "For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with mislabeled examples signified by a small black border."
            )

        col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])
        df = context.df_tokens_merged.copy()

        # Bind every setting up front so each name is defined regardless of
        # which algorithm branch creates its widget below.
        defaults = self.get_widget_defaults()
        dim_algo = "SVD"
        n_tokens = 100
        svd_n_iter = defaults["svd_n_iter"]
        umap_n_neighbors = defaults["umap_n_neighbors"]
        umap_min_dist = defaults["umap_min_dist"]
        umap_metric = defaults["umap_metric"]

        with col1:
            st.subheader("Settings")
            n_tokens = st.slider(
                "#tokens",
                key="n_tokens",
                min_value=100,
                # The value is used as a row count (``df[:n_tokens]``), so the
                # upper bound is the number of rows, not the number of
                # distinct token strings.
                max_value=len(df),
                step=100,
            )

            dim_algo = st.selectbox("Dimensionality reduction algorithm", ["SVD", "PCA", "UMAP"])
            if dim_algo == "SVD":
                svd_n_iter = st.slider(
                    "#iterations",
                    key="svd_n_iter",
                    min_value=1,
                    max_value=10,
                    step=1,
                )
            elif dim_algo == "UMAP":
                umap_n_neighbors = st.slider(
                    "#neighbors",
                    key="umap_n_neighbors",
                    min_value=2,
                    max_value=100,
                    step=1,
                )
                umap_min_dist = st.number_input(
                    "Min distance", key="umap_min_dist", value=0.1, min_value=0.0, max_value=1.0
                )
                # Keyed like the other settings so it lines up with the
                # "umap_metric" entry in get_widget_defaults().
                umap_metric = st.selectbox(
                    "Metric",
                    ["euclidean", "manhattan", "chebyshev", "minkowski"],
                    key="umap_metric",
                )
            # PCA exposes no settings.

        with col2:
            # One space-joined string of tokens per sentence id.
            sents = df.groupby("ids").apply(lambda x: " ".join(x["tokens"].tolist()))

            # The projection is fitted on all rows; only the display below is
            # limited to the first ``n_tokens`` rows.
            X = np.array(df["hidden_states"].tolist())
            transformed_hidden_states = None
            if dim_algo == "SVD":
                transformed_hidden_states = reduce_dim_svd(X, n_iter=svd_n_iter)
            elif dim_algo == "PCA":
                transformed_hidden_states = reduce_dim_pca(X)
            elif dim_algo == "UMAP":
                transformed_hidden_states = reduce_dim_umap(
                    X, n_neighbors=umap_n_neighbors, min_dist=umap_min_dist, metric=umap_metric
                )

            assert isinstance(transformed_hidden_states, np.ndarray)
            df["x"] = transformed_hidden_states[:, 0]
            df["y"] = transformed_hidden_states[:, 1]

            # Build the hover columns. Each chunk is sliced once per sentence
            # and then looked up per row, instead of re-slicing for every row.
            sent_columns = []
            for chunk_idx in range(self._HOVER_CHUNK_COUNT):
                start = chunk_idx * self._HOVER_CHUNK_CHARS
                stop = start + self._HOVER_CHUNK_CHARS
                chunk_by_id = {
                    sent_id: " ".join(sentence[start:stop].split())
                    for sent_id, sentence in sents.items()
                }
                column = f"sent{chunk_idx}"
                df[column] = df["ids"].map(chunk_by_id)
                sent_columns.append(column)

            df["mislabeled"] = df["labels"] != df["preds"]

            subset = df[:n_tokens]
            mislabeled_subset = subset[subset["mislabeled"]]
            # Transparent fill plus a thin outline: drawn over the colored
            # points so that only a border is added to mislabeled rows.
            mislabeled_examples_trace = go.Scatter(
                x=mislabeled_subset["x"],
                y=mislabeled_subset["y"],
                mode="markers",
                marker=dict(
                    size=6,
                    color="rgba(0,0,0,0)",
                    line=dict(width=1),
                ),
                hoverinfo="skip",
            )

            st.subheader("Projection Results")

            fig = px.scatter(
                subset,
                x="x",
                y="y",
                color="labels",
                hover_data=["ids", "preds", *sent_columns],
                hover_name="tokens",
                title="Colored by label",
            )
            fig.add_trace(mislabeled_examples_trace)
            st.plotly_chart(fig)

            fig = px.scatter(
                subset,
                x="x",
                y="y",
                color="preds",
                hover_data=["ids", "labels", *sent_columns],
                hover_name="tokens",
                title="Colored by prediction",
            )
            fig.add_trace(mislabeled_examples_trace)
            st.plotly_chart(fig)