Alexander Seifert commited on
Commit
fb9cb6e
β€’
1 Parent(s): 17ba05a
main.py CHANGED
@@ -17,8 +17,9 @@ from subpages import (
17
  RawDataPage,
18
  )
19
  from subpages.attention import AttentionPage
20
- from subpages.embeddings import EmbeddingsPage
21
  from subpages.inspect import InspectPage
 
22
 
23
  sts = st.sidebar
24
  st.set_page_config(
@@ -54,25 +55,11 @@ def _write_color_legend(context):
54
  def style(x):
55
  return [f"background-color: {rgb}; opacity: 1;" for rgb in colors]
56
 
57
- labelmap = {
58
- "O": "O",
59
- "person": "πŸ™Ž",
60
- "PER": "πŸ™Ž",
61
- "location": "🌎",
62
- "LOC": "🌎",
63
- "corporation": "🏀",
64
- "ORG": "🏀",
65
- "product": "πŸ“±",
66
- "creative": "🎷",
67
- "group": "🎷",
68
- "MISC": "🎷",
69
- }
70
-
71
  labels = list(set([lbl.split("-")[1] if "-" in lbl else lbl for lbl in context.labels]))
72
  colors = [st.session_state.get(f"color_{lbl}", "#000000") for lbl in labels]
73
 
74
  color_legend_df = pd.DataFrame(
75
- [labelmap[l] for l in labels], columns=["label"], index=labels
76
  ).T
77
  st.sidebar.write(
78
  color_legend_df.T.style.apply(style, axis=0).set_properties(
@@ -85,7 +72,7 @@ def main():
85
  pages: list[Page] = [
86
  HomePage(),
87
  AttentionPage(),
88
- EmbeddingsPage(),
89
  ProbingPage(),
90
  MetricsPage(),
91
  MisclassifiedPage(),
 
17
  RawDataPage,
18
  )
19
  from subpages.attention import AttentionPage
20
+ from subpages.hidden_states import HiddenStatesPage
21
  from subpages.inspect import InspectPage
22
+ from utils import classmap
23
 
24
  sts = st.sidebar
25
  st.set_page_config(
 
55
  def style(x):
56
  return [f"background-color: {rgb}; opacity: 1;" for rgb in colors]
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  labels = list(set([lbl.split("-")[1] if "-" in lbl else lbl for lbl in context.labels]))
59
  colors = [st.session_state.get(f"color_{lbl}", "#000000") for lbl in labels]
60
 
61
  color_legend_df = pd.DataFrame(
62
+ [classmap[l] for l in labels], columns=["label"], index=labels
63
  ).T
64
  st.sidebar.write(
65
  color_legend_df.T.style.apply(style, axis=0).set_properties(
 
72
  pages: list[Page] = [
73
  HomePage(),
74
  AttentionPage(),
75
+ HiddenStatesPage(),
76
  ProbingPage(),
77
  MetricsPage(),
78
  MisclassifiedPage(),
subpages/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
  from subpages.attention import AttentionPage
2
  from subpages.debug import DebugPage
3
- from subpages.embeddings import EmbeddingsPage
4
  from subpages.find_duplicates import FindDuplicatesPage
 
5
  from subpages.home import HomePage
6
  from subpages.inspect import InspectPage
7
  from subpages.losses import LossesPage
 
1
  from subpages.attention import AttentionPage
2
  from subpages.debug import DebugPage
 
3
  from subpages.find_duplicates import FindDuplicatesPage
4
+ from subpages.hidden_states import HiddenStatesPage
5
  from subpages.home import HomePage
6
  from subpages.inspect import InspectPage
7
  from subpages.losses import LossesPage
subpages/{embeddings.py β†’ hidden_states.py} RENAMED
@@ -28,8 +28,8 @@ def reduce_dim_umap(X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
28
  return UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit_transform(X)
29
 
30
 
31
- class EmbeddingsPage(Page):
32
- name = "Embeddings"
33
  icon = "grid-3x3"
34
 
35
  def get_widget_defaults(self):
 
28
  return UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit_transform(X)
29
 
30
 
31
+ class HiddenStatesPage(Page):
32
+ name = "Hidden States"
33
  icon = "grid-3x3"
34
 
35
  def get_widget_defaults(self):
subpages/home.py CHANGED
@@ -3,11 +3,10 @@ import random
3
  from typing import Optional
4
 
5
  import streamlit as st
6
- from pandas import wide_to_long
7
 
8
  from data import get_data
9
  from subpages.page import Context, Page
10
- from utils import color_map_color
11
 
12
  _SENTENCE_ENCODER_MODEL = (
13
  "sentence-transformers/all-MiniLM-L6-v2",
@@ -53,7 +52,7 @@ class HomePage(Page):
53
 
54
  with st.expander("πŸ’‘", expanded=True):
55
  st.write(
56
- "**Error Analysis is an important but often overlooked part of the data science project lifecycle**, for which there is still very little tooling available. Practitioners tend to write throwaway code or, worse, skip this crucial step of understanding their models' errors altogether. This project tries to provide an **extensive toolkit to probe any NER model/dataset combination**, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further improvements."
57
  )
58
 
59
  col1, _, col2a, col2b = st.columns([1, 0.05, 0.15, 0.15])
@@ -91,7 +90,7 @@ class HomePage(Page):
91
  st.text_input(
92
  label="Encoder Model:",
93
  key="encoder_model_name",
94
- help="Path or name of the encoder to use",
95
  )
96
  ds_name = st.text_input(
97
  label="Dataset:",
@@ -136,8 +135,9 @@ class HomePage(Page):
136
  emojis = list(json.load(open("subpages/emoji-en-US.json")).keys())
137
  for label in labels:
138
  if f"icon_{label}" not in st.session_state:
139
- st.session_state[f"icon_{label}"] = "πŸ€—" # labels[label]
140
  st.selectbox(label, key=f"icon_{label}", options=emojis)
 
141
 
142
  # if st.button("Reset to defaults"):
143
  # st.session_state.update(**get_home_page_defaults())
 
3
  from typing import Optional
4
 
5
  import streamlit as st
 
6
 
7
  from data import get_data
8
  from subpages.page import Context, Page
9
+ from utils import classmap, color_map_color
10
 
11
  _SENTENCE_ENCODER_MODEL = (
12
  "sentence-transformers/all-MiniLM-L6-v2",
 
52
 
53
  with st.expander("πŸ’‘", expanded=True):
54
  st.write(
55
+ "**Error Analysis is an important but often overlooked part of the data science project lifecycle**, for which there is still very little tooling available. Practitioners tend to write throwaway code or, worse, skip this crucial step of understanding their models' errors altogether. This project tries to provide an **extensive toolkit to probe any NER model/dataset combination**, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further **improving both model AND dataset**."
56
  )
57
 
58
  col1, _, col2a, col2b = st.columns([1, 0.05, 0.15, 0.15])
 
90
  st.text_input(
91
  label="Encoder Model:",
92
  key="encoder_model_name",
93
+ help="Path or name of the encoder to use for duplicate detection",
94
  )
95
  ds_name = st.text_input(
96
  label="Dataset:",
 
135
  emojis = list(json.load(open("subpages/emoji-en-US.json")).keys())
136
  for label in labels:
137
  if f"icon_{label}" not in st.session_state:
138
+ st.session_state[f"icon_{label}"] = classmap[label]
139
  st.selectbox(label, key=f"icon_{label}", options=emojis)
140
+ classmap[label] = st.session_state[f"icon_{label}"]
141
 
142
  # if st.button("Reset to defaults"):
143
  # st.session_state.update(**get_home_page_defaults())
utils.py CHANGED
@@ -14,6 +14,19 @@ tokenizer_hash_funcs = {
14
  # device = torch.device("cuda" if torch.cuda.is_available() else "cpu" if torch.has_mps else "cpu")
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def aggrid_interactive_table(df: pd.DataFrame) -> dict:
19
  """Creates an st-aggrid interactive table based on a dataframe.
@@ -159,18 +172,6 @@ def colorize_classes(df: pd.DataFrame) -> pd.DataFrame:
159
 
160
  def htmlify_labeled_example(example: pd.DataFrame) -> str:
161
  html = []
162
- classmap = {
163
- "O": "O",
164
- "PER": "πŸ™Ž",
165
- "person": "πŸ™Ž",
166
- "LOC": "🌎",
167
- "location": "🌎",
168
- "ORG": "🏀",
169
- "corporation": "🏀",
170
- "product": "πŸ“±",
171
- "creative": "🎷",
172
- "MISC": "🎷",
173
- }
174
 
175
  for _, row in example.iterrows():
176
  pred = row.preds.split("-")[1] if "-" in row.preds else "O"
 
14
  # device = torch.device("cuda" if torch.cuda.is_available() else "cpu" if torch.has_mps else "cpu")
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
+ classmap = {
18
+ "O": "O",
19
+ "PER": "πŸ™Ž",
20
+ "person": "πŸ™Ž",
21
+ "LOC": "🌎",
22
+ "location": "🌎",
23
+ "ORG": "🏀",
24
+ "corporation": "🏀",
25
+ "product": "πŸ“±",
26
+ "creative": "🎷",
27
+ "MISC": "🎷",
28
+ }
29
+
30
 
31
  def aggrid_interactive_table(df: pd.DataFrame) -> dict:
32
  """Creates an st-aggrid interactive table based on a dataframe.
 
172
 
173
  def htmlify_labeled_example(example: pd.DataFrame) -> str:
174
  html = []
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  for _, row in example.iterrows():
177
  pred = row.preds.split("-")[1] if "-" in row.preds else "O"