Alexander Seifert commited on
Commit
408486e
1 Parent(s): 4c13f39

improve docs

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. html/index.md +1 -1
  3. src/subpages/hidden_states.py +8 -8
README.md CHANGED
@@ -24,7 +24,7 @@ A group of neurons tend to fire in response to commas and other punctuation. Oth
24
 
25
  ### Embeddings
26
 
27
- For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with mislabeled examples marked by a small black border.
28
 
29
 
30
  ### Probing
24
 
25
  ### Embeddings
26
 
27
+ For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements marked by a small black border.
28
 
29
 
30
  ### Probing
html/index.md CHANGED
@@ -51,7 +51,7 @@ Activations
51
 
52
  Hidden States
53
 
54
- > For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with mislabeled examples marked by a small black border.
55
  >
56
  > Using these projections you can visually identify data points that end up in the wrong neighborhood, indicating prediction/labeling errors.
57
 
51
 
52
  Hidden States
53
 
54
+ > For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements marked by a small black border.
55
  >
56
  > Using these projections you can visually identify data points that end up in the wrong neighborhood, indicating prediction/labeling errors.
57
 
src/subpages/hidden_states.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with mislabeled examples marked by a small black border.
3
  """
4
  import numpy as np
5
  import plotly.express as px
@@ -86,7 +86,7 @@ class HiddenStatesPage(Page):
86
 
87
  with st.expander("💡", expanded=True):
88
  st.write(
89
- "For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with mislabeled examples signified by a small black border."
90
  )
91
 
92
  col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])
@@ -152,12 +152,12 @@ class HiddenStatesPage(Page):
152
  df["sent2"] = df["ids"].map(lambda x: " ".join(sents[x][100:150].split()))
153
  df["sent3"] = df["ids"].map(lambda x: " ".join(sents[x][150:200].split()))
154
  df["sent4"] = df["ids"].map(lambda x: " ".join(sents[x][200:250].split()))
155
- df["mislabeled"] = df["labels"] != df["preds"]
156
 
157
  subset = df[:n_tokens]
158
- mislabeled_examples_trace = go.Scatter(
159
- x=subset[subset["mislabeled"]]["x"],
160
- y=subset[subset["mislabeled"]]["y"],
161
  mode="markers",
162
  marker=dict(
163
  size=6,
@@ -178,7 +178,7 @@ class HiddenStatesPage(Page):
178
  hover_name="tokens",
179
  title="Colored by label",
180
  )
181
- fig.add_trace(mislabeled_examples_trace)
182
  st.plotly_chart(fig)
183
 
184
  fig = px.scatter(
@@ -190,5 +190,5 @@ class HiddenStatesPage(Page):
190
  hover_name="tokens",
191
  title="Colored by prediction",
192
  )
193
- fig.add_trace(mislabeled_examples_trace)
194
  st.plotly_chart(fig)
1
  """
2
+ For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements marked by a small black border.
3
  """
4
  import numpy as np
5
  import plotly.express as px
86
 
87
  with st.expander("💡", expanded=True):
88
  st.write(
89
+ "For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements signified by a small black border."
90
  )
91
 
92
  col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])
152
  df["sent2"] = df["ids"].map(lambda x: " ".join(sents[x][100:150].split()))
153
  df["sent3"] = df["ids"].map(lambda x: " ".join(sents[x][150:200].split()))
154
  df["sent4"] = df["ids"].map(lambda x: " ".join(sents[x][200:250].split()))
155
+ df["disagreements"] = df["labels"] != df["preds"]
156
 
157
  subset = df[:n_tokens]
158
+ disagreements_trace = go.Scatter(
159
+ x=subset[subset["disagreements"]]["x"],
160
+ y=subset[subset["disagreements"]]["y"],
161
  mode="markers",
162
  marker=dict(
163
  size=6,
178
  hover_name="tokens",
179
  title="Colored by label",
180
  )
181
+ fig.add_trace(disagreements_trace)
182
  st.plotly_chart(fig)
183
 
184
  fig = px.scatter(
190
  hover_name="tokens",
191
  title="Colored by prediction",
192
  )
193
+ fig.add_trace(disagreements_trace)
194
  st.plotly_chart(fig)