Alexander Seifert committed
Commit: 554bac5
Parent: 8778b89

improve docs

src/data.py CHANGED
@@ -46,7 +46,16 @@ def get_collator(tokenizer) -> DataCollatorForTokenClassification:
     return DataCollatorForTokenClassification(tokenizer)
 
 
-def create_word_ids_from_tokens(tokenizer, input_ids: list[int]):
+def create_word_ids_from_input_ids(tokenizer, input_ids: list[int]) -> list[int]:
+    """Takes a list of input_ids and returns the corresponding word_ids.
+
+    Args:
+        tokenizer: The tokenizer that was used to obtain the input ids.
+        input_ids (list[int]): List of token ids.
+
+    Returns:
+        list[int]: Word ids corresponding to the input ids.
+    """
     word_ids = []
     wid = -1
     tokens = [tokenizer.convert_ids_to_tokens(i) for i in input_ids]
@@ -65,16 +74,27 @@ def create_word_ids_from_tokens(tokenizer, input_ids: list[int]):
     return word_ids
 
 
-def tokenize_and_align_labels(examples, tokenizer):
-    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
+def tokenize(batch, tokenizer) -> dict:
+    """Tokenizes a batch of examples.
+
+    Args:
+        batch: The examples to tokenize
+        tokenizer: The tokenizer to use
+
+    Returns:
+        dict: The tokenized batch
+    """
+    tokenized_inputs = tokenizer(batch["tokens"], truncation=True, is_split_into_words=True)
     labels = []
     wids = []
 
-    for idx, label in enumerate(examples["ner_tags"]):
+    for idx, label in enumerate(batch["ner_tags"]):
         try:
             word_ids = tokenized_inputs.word_ids(batch_index=idx)
         except ValueError:
-            word_ids = create_word_ids_from_tokens(tokenizer, tokenized_inputs["input_ids"][idx])
+            word_ids = create_word_ids_from_input_ids(
+                tokenizer, tokenized_inputs["input_ids"][idx]
+            )
         previous_word_idx = None
         label_ids = []
         for word_idx in word_ids:
@@ -119,7 +139,7 @@ def encode_dataset(split: Dataset, tokenizer):
     remove_columns = split.column_names
     ids = split["id"]
     split = split.map(
-        partial(tokenize_and_align_labels, tokenizer=tokenizer),
+        partial(tokenize, tokenizer=tokenizer),
        batched=True,
        remove_columns=remove_columns,
    )
@@ -128,6 +148,18 @@ def encode_dataset(split: Dataset, tokenizer):
 
 
 def forward_pass_with_label(batch, model, collator, num_classes: int) -> dict:
+    """Runs the forward pass for a batch of examples.
+
+    Args:
+        batch: The batch to process
+        model: The model to process the batch with
+        collator: A data collator
+        num_classes (int): Number of classes
+
+    Returns:
+        dict: A dictionary containing `losses`, `preds` and `hidden_states`
+    """
+
     # Convert dict of lists to list of dicts suitable for data collator
     features = [dict(zip(batch, t)) for t in zip(*batch.values())]
 
@@ -159,19 +191,20 @@ def forward_pass_with_label(batch, model, collator, num_classes: int) -> dict:
     return {"losses": loss, "preds": preds, "hidden_states": hidden_states}
 
 
-def get_split_df(split_encoded: Dataset, model, tokenizer, collator, tags) -> pd.DataFrame:
-    """Turns a Dataset into a pandas dataframe.
+def predict(split_encoded: Dataset, model, tokenizer, collator, tags) -> pd.DataFrame:
+    """Generates predictions for a given dataset split and returns the results as a dataframe.
 
     Args:
-        split_encoded (Dataset): _description_
-        model (_type_): _description_
-        tokenizer (_type_): _description_
-        collator (_type_): _description_
-        tags (_type_): _description_
+        split_encoded (Dataset): The dataset to process
+        model: The model to process the dataset with
+        tokenizer: The tokenizer to process the dataset with
+        collator: The data collator to use
+        tags: The tags used in the dataset
 
     Returns:
-        pd.DataFrame: _description_
+        pd.DataFrame: A dataframe containing token-level predictions.
     """
+
     split_encoded = split_encoded.map(
         partial(
             forward_pass_with_label,
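
The renamed helpers keep the same call pattern as before: `tokenize` is mapped over a split in batches (as `encode_dataset` does) and `predict` then turns the encoded split into a token-level dataframe. Below is a minimal sketch of that flow, assuming a fast Hugging Face tokenizer and a toy split with the `tokens`/`ner_tags` columns these functions expect; the checkpoint and sample data are illustrative only.

# Minimal sketch of the renamed pipeline (illustrative names and data).
from functools import partial

from datasets import Dataset
from transformers import AutoTokenizer

from src.data import tokenize, predict  # formerly tokenize_and_align_labels / get_split_df

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # any fast tokenizer

# A toy split with the columns `tokenize` expects.
toy_split = Dataset.from_dict(
    {
        "tokens": [["Alexander", "lives", "in", "Vienna"]],
        "ner_tags": [[1, 0, 0, 2]],
    }
)

# Batched tokenization + label alignment, exactly as in `encode_dataset`.
encoded = toy_split.map(
    partial(tokenize, tokenizer=tokenizer),
    batched=True,
    remove_columns=toy_split.column_names,
)

# With a model, collator and tag list in hand (see src/load.py below), the encoded
# split is turned into a token-level dataframe:
# df = predict(encoded, model, tokenizer, collator, tags)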
src/load.py CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
 import streamlit as st
 from datasets import Dataset  # type: ignore
 
-from src.data import encode_dataset, get_collator, get_data, get_split_df
+from src.data import encode_dataset, get_collator, get_data, predict
 from src.model import get_encoder, get_model, get_tokenizer
 from src.subpages import Context
 from src.utils import align_sample, device, explode_df
@@ -68,7 +68,7 @@ def load_context(
     split_encoded, word_ids, ids = encode_dataset(split, tokenizer)
 
     # transform into dataframe
-    df = get_split_df(split_encoded, model, tokenizer, collator, tags)
+    df = predict(split_encoded, model, tokenizer, collator, tags)
     df["word_ids"] = word_ids
     df["ids"] = ids
 
src/subpages/attention.py CHANGED
@@ -7,7 +7,7 @@ from streamlit.components.v1 import html
 
 from src.subpages.page import Context, Page  # type: ignore
 
-SETUP_HTML = """
+_SETUP_HTML = """
 <script src="https://requirejs.org/docs/release/2.3.6/minified/require.js"></script>
 <script>
 var ecco_url = 'https://storage.googleapis.com/ml-intro/ecco/'
@@ -70,17 +70,6 @@ SETUP_HTML = """
 <div id="basic"></div>
 """
 
-JS_TEMPLATE = """requirejs(['basic', 'ecco'], function(basic, ecco){{
-    const viz_id = basic.init()
-
-    ecco.interactiveTokensAndFactorSparklines(viz_id, {}, {{
-        'hltrCFG': {{'tokenization_config': {{'token_prefix': '', 'partial_token_prefix': '##'}}
-        }}
-    }})
-}}, function (err) {{
-    console.log(err);
-}})"""
-
 
 @st.cache(allow_output_mutation=True)
 def _load_ecco_model():
@@ -160,10 +149,10 @@ class AttentionPage(Page):
         output = lm(inputs)
         nmf = output.run_nmf(n_components=n_components, from_layer=from_layer, to_layer=to_layer)
         data = nmf.explore(returnData=True)
-        JS_TEMPLATE = f"""<script>requirejs(['basic', 'ecco'], function(basic, ecco){{
+        _JS_TEMPLATE = f"""<script>requirejs(['basic', 'ecco'], function(basic, ecco){{
             const viz_id = basic.init()
             ecco.interactiveTokensAndFactorSparklines(viz_id, {data}, {{ 'hltrCFG': {{'tokenization_config': {{'token_prefix': '', 'partial_token_prefix': '##'}} }} }})
         }}, function (err) {{
             console.log(err);
         }})</script>"""
-        html(SETUP_HTML + JS_TEMPLATE, height=800, scrolling=True)
+        html(_SETUP_HTML + _JS_TEMPLATE, height=800, scrolling=True)
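
The renamed constants keep the pattern this page relies on: a static setup snippet defined once at module level, plus a per-render script that interpolates the NMF data before both are handed to Streamlit's `html` component. The following is a stripped-down sketch of that pattern with placeholder markup (the real `_SETUP_HTML` loads require.js and the ecco assets), not the page's actual code.

# Sketch of the setup-HTML + per-render-script pattern (placeholder content only).
from streamlit.components.v1 import html

_SETUP_HTML = """
<script>/* load require.js and visualization assets here */</script>
<div id="basic"></div>
"""


def render(data: dict) -> None:
    # Built per render because it interpolates `data` into the script.
    js = f"<script>console.log({data});</script>"
    html(_SETUP_HTML + js, height=800, scrolling=True)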
src/subpages/hidden_states.py CHANGED
@@ -10,7 +10,19 @@ from src.subpages.page import Context, Page
 
 
 @st.cache
-def reduce_dim_svd(X, n_iter, random_state=42):
+def reduce_dim_svd(X, n_iter: int, random_state=42):
+    """Dimensionality reduction using truncated SVD (aka LSA).
+
+    This transformer performs linear dimensionality reduction by means of truncated singular value decomposition (SVD). Contrary to PCA, this estimator does not center the data before computing the singular value decomposition. This means it can work with sparse matrices efficiently.
+
+    Args:
+        X: Training data
+        n_iter (int): Number of iterations for the randomized SVD solver.
+        random_state (int, optional): Used during randomized svd. Pass an int for reproducible results across multiple function calls. Defaults to 42.
+
+    Returns:
+        ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
+    """
     from sklearn.decomposition import TruncatedSVD
 
     svd = TruncatedSVD(n_components=2, n_iter=n_iter, random_state=random_state)
@@ -19,6 +31,17 @@ def reduce_dim_svd(X, n_iter, random_state=42):
 
 @st.cache
 def reduce_dim_pca(X, random_state=42):
+    """Principal component analysis (PCA).
+
+    Linear dimensionality reduction using Singular Value Decomposition of the data to project it to a lower dimensional space. The input data is centered but not scaled for each feature before applying the SVD.
+
+    Args:
+        X: Training data
+        random_state (int, optional): Used when the 'arpack' or 'randomized' solvers are used. Pass an int for reproducible results across multiple function calls. Defaults to 42.
+
+    Returns:
+        ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
+    """
     from sklearn.decomposition import PCA
 
     return PCA(n_components=2, random_state=random_state).fit_transform(X)
@@ -26,6 +49,19 @@ def reduce_dim_pca(X, random_state=42):
 
 @st.cache
 def reduce_dim_umap(X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
+    """Uniform Manifold Approximation and Projection (UMAP).
+
+    Finds a low dimensional embedding of the data that approximates an underlying manifold.
+
+    Args:
+        X: Training data
+        n_neighbors (int, optional): The size of the local neighborhood (in terms of number of neighboring sample points) used for manifold approximation. Larger values result in more global views of the manifold, while smaller values result in more local data being preserved. In general, values should be in the range 2 to 100. Defaults to 5.
+        min_dist (float, optional): The effective minimum distance between embedded points. Smaller values will result in a more clustered/clumped embedding where nearby points on the manifold are drawn closer together, while larger values will result in a more even dispersal of points. The value should be set relative to the `spread` value, which determines the scale at which embedded points will be spread out. Defaults to 0.1.
+        metric (str, optional): The metric to use to compute distances in high dimensional space (see the UMAP docs for options). Defaults to "euclidean".
+
+    Returns:
+        ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
+    """
     from umap import UMAP
 
     return UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit_transform(X)
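
Apart from the `@st.cache` decorator, the three reducers are thin wrappers around their estimators, so their behaviour can be checked in isolation. Here is a standalone sketch with random data standing in for real hidden states (it assumes scikit-learn and umap-learn are installed; the array shape is illustrative).

# Standalone sketch of the three 2-D reducers documented above.
import numpy as np
from sklearn.decomposition import PCA, TruncatedSVD
from umap import UMAP

X = np.random.default_rng(0).normal(size=(200, 768))  # e.g. 200 tokens x 768 hidden dims

X_svd = TruncatedSVD(n_components=2, n_iter=5, random_state=42).fit_transform(X)
X_pca = PCA(n_components=2, random_state=42).fit_transform(X)
X_umap = UMAP(n_neighbors=5, min_dist=0.1, metric="euclidean").fit_transform(X)

print(X_svd.shape, X_pca.shape, X_umap.shape)  # each is (200, 2)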
src/utils.py CHANGED
@@ -34,6 +34,7 @@ classmap = {
 
 def aggrid_interactive_table(df: pd.DataFrame) -> dict:
     """Creates an st-aggrid interactive table based on a dataframe.
+
     Args:
         df (pd.DataFrame]): Source dataframe
     Returns:
@@ -60,6 +61,8 @@ def aggrid_interactive_table(df: pd.DataFrame) -> dict:
 
 
 def explode_df(df: pd.DataFrame) -> pd.DataFrame:
+    """Takes a dataframe and explodes all the fields."""
+
     df_tokens = df.apply(pd.Series.explode)
     if "losses" in df.columns:
         df_tokens["losses"] = df_tokens["losses"].astype(float)
@@ -67,7 +70,7 @@ def explode_df(df: pd.DataFrame) -> pd.DataFrame:
 
 
 def align_sample(row: pd.Series):
-    """Use word_ids to align all lists in a sample."""
+    """Uses word_ids to align all lists in a sample."""
 
     columns = row.axes[0].to_list()
     indices = [i for i, id in enumerate(row.word_ids) if id >= 0 and id != row.word_ids[i - 1]]
@@ -113,7 +116,17 @@ def align_sample(row: pd.Series):
     hash_funcs=tokenizer_hash_funcs,
 )
 def tag_text(text: str, tokenizer, model, device: torch.device) -> pd.DataFrame:
-    """Create an (exploded) DataFrame with the predicted labels and probabilities."""
+    """Tags a given text and creates an (exploded) DataFrame with the predicted labels and probabilities.
+
+    Args:
+        text (str): The text to be processed
+        tokenizer: Tokenizer to use
+        model: Model to use
+        device (torch.device): The device we want pytorch to use for its calculations.
+
+    Returns:
+        pd.DataFrame: A data frame holding the tagged text.
+    """
 
     tokens = tokenizer(text).tokens()
     tokenized = tokenizer(text, return_tensors="pt")
@@ -137,21 +150,31 @@ def tag_text(text: str, tokenizer, model, device: torch.device) -> pd.DataFrame:
     return explode_df(merged_df).reset_index().drop(columns=["index"])
 
 
-def get_bg_color(label):
+def get_bg_color(label: str):
+    """Retrieves a label's color from the session state."""
     return st.session_state[f"color_{label}"]
 
 
-def get_fg_color(hex_color: str) -> str:
-    """Adapted from https://gomakethings.com/dynamically-changing-the-text-color-based-on-background-color-contrast-with-vanilla-js/"""
-    r = int(hex_color[1:3], 16)
-    g = int(hex_color[3:5], 16)
-    b = int(hex_color[5:7], 16)
+def get_fg_color(bg_color_hex: str) -> str:
+    """Chooses the proper (foreground) text color (black/white) for a given background color, maximizing contrast.
+
+    Adapted from https://gomakethings.com/dynamically-changing-the-text-color-based-on-background-color-contrast-with-vanilla-js/
+
+    Args:
+        bg_color_hex (str): The background color given as a HEX string.
+
+    Returns:
+        str: Either "black" or "white".
+    """
+    r = int(bg_color_hex[1:3], 16)
+    g = int(bg_color_hex[3:5], 16)
+    b = int(bg_color_hex[5:7], 16)
     yiq = ((r * 299) + (g * 587) + (b * 114)) / 1000
     return "black" if (yiq >= 128) else "white"
 
 
 def colorize_classes(df: pd.DataFrame) -> pd.DataFrame:
-    """Colorize the errors in the dataframe."""
+    """Colorizes the errors in the dataframe."""
 
     def colorize_row(row):
         return [
@@ -175,6 +198,14 @@ def colorize_classes(df: pd.DataFrame) -> pd.DataFrame:
 
 
 def htmlify_labeled_example(example: pd.DataFrame) -> str:
+    """Builds an HTML (string) representation of a single example.
+
+    Args:
+        example (pd.DataFrame): The example to process.
+
+    Returns:
+        str: An HTML string representation of a single example.
+    """
     html = []
 
     for _, row in example.iterrows():
@@ -215,18 +246,8 @@ def htmlify_labeled_example(example: pd.DataFrame) -> str:
     return " ".join(html)
 
 
-def htmlify_example(example: pd.DataFrame) -> str:
-    corr_html = " ".join(
-        [
-            f", {row.tokens}" if row.labels == "B-COMMA" else row.tokens
-            for _, row in example.iterrows()
-        ]
-    ).strip()
-    return f"<em>{corr_html}</em>"
-
-
 def color_map_color(value: float, cmap_name="Set1", vmin=0, vmax=1) -> str:
-    """Turn a value into a color using a color map."""
+    """Turns a value into a color using a color map."""
     norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
     cmap = cm.get_cmap(cmap_name)  # PiYG
     rgba = cmap(norm(abs(value)))
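
The `get_fg_color` docstring describes a YIQ-based contrast rule: the hex color is split into RGB, converted to a perceived-brightness value between 0 and 255, and the text color flips at the midpoint. A quick standalone check with two hypothetical background colors:

# Standalone illustration of the YIQ contrast rule used by get_fg_color.
def fg_for_bg(bg_color_hex: str) -> str:
    r = int(bg_color_hex[1:3], 16)
    g = int(bg_color_hex[3:5], 16)
    b = int(bg_color_hex[5:7], 16)
    yiq = (r * 299 + g * 587 + b * 114) / 1000  # perceived brightness, 0-255
    return "black" if yiq >= 128 else "white"


print(fg_for_bg("#ffde59"))  # light yellow background -> "black" text
print(fg_for_bg("#1f3b73"))  # dark blue background    -> "white" text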