Alexander Seifert commited on
Commit
9556889
1 Parent(s): 2918df9

improve docs

Browse files
Files changed (4) hide show
  1. src/data.py +53 -2
  2. src/load.py +14 -0
  3. src/subpages/attention.py +2 -2
  4. src/subpages/page.py +10 -0
src/data.py CHANGED
@@ -11,7 +11,19 @@ from utils import device, tokenizer_hash_funcs
11
 
12
 
13
  @st.cache(allow_output_mutation=True)
14
- def get_data(ds_name, config_name, split_name, split_sample_size) -> Dataset:
 
 
 
 
 
 
 
 
 
 
 
 
15
  ds: DatasetDict = load_dataset(ds_name, name=config_name, use_auth_token=True).shuffle(seed=0) # type: ignore
16
  split = ds[split_name].select(range(split_sample_size))
17
  return split
@@ -22,6 +34,14 @@ def get_data(ds_name, config_name, split_name, split_sample_size) -> Dataset:
22
  hash_funcs=tokenizer_hash_funcs,
23
  )
24
  def get_collator(tokenizer) -> DataCollatorForTokenClassification:
 
 
 
 
 
 
 
 
25
  return DataCollatorForTokenClassification(tokenizer)
26
 
27
 
@@ -70,10 +90,29 @@ def tokenize_and_align_labels(examples, tokenizer):
70
 
71
 
72
  def stringify_ner_tags(batch, tags):
 
 
 
 
 
 
 
 
 
73
  return {"ner_tags_str": [tags.int2str(idx) for idx in batch["ner_tags"]]}
74
 
75
 
76
- def encode_dataset(split, tokenizer):
 
 
 
 
 
 
 
 
 
 
77
  tags = split.features["ner_tags"].feature
78
  split = split.map(partial(stringify_ner_tags, tags=tags), batched=True)
79
  remove_columns = split.column_names
@@ -120,6 +159,18 @@ def forward_pass_with_label(batch, model, collator, num_classes: int) -> dict:
120
 
121
 
122
  def get_split_df(split_encoded: Dataset, model, tokenizer, collator, tags) -> pd.DataFrame:
 
 
 
 
 
 
 
 
 
 
 
 
123
  split_encoded = split_encoded.map(
124
  partial(
125
  forward_pass_with_label,
 
11
 
12
 
13
  @st.cache(allow_output_mutation=True)
14
+ def get_data(ds_name: str, config_name: str, split_name: str, split_sample_size: int) -> Dataset:
15
+ """Loads dataset from the HF hub (if not already loaded) and returns a Dataset object.
16
+ Uses datasets.load_dataset to load the dataset (see its documentation for additional details).
17
+
18
+ Args:
19
+ ds_name (str): Path or name of the dataset.
20
+ config_name (str): Name of the dataset configuration.
21
+ split_name (str): Which split of the data to load.
22
+ split_sample_size (int): The number of examples to load from the split.
23
+
24
+ Returns:
25
+ Dataset: A Dataset object.
26
+ """
27
  ds: DatasetDict = load_dataset(ds_name, name=config_name, use_auth_token=True).shuffle(seed=0) # type: ignore
28
  split = ds[split_name].select(range(split_sample_size))
29
  return split
 
34
  hash_funcs=tokenizer_hash_funcs,
35
  )
36
  def get_collator(tokenizer) -> DataCollatorForTokenClassification:
37
+ """Data collator that will dynamically pad the inputs received, as well as the labels.
38
+
39
+ Args:
40
+ tokenizer ([PreTrainedTokenizer] or [PreTrainedTokenizerFast]): The tokenizer used for encoding the data.
41
+
42
+ Returns:
43
+ DataCollatorForTokenClassification: The DataCollatorForTokenClassification object.
44
+ """
45
  return DataCollatorForTokenClassification(tokenizer)
46
 
47
 
 
90
 
91
 
92
  def stringify_ner_tags(batch, tags):
93
+ """Stringifies a dataset batch's NER tags.
94
+
95
+ Args:
96
+ batch (_type_): _description_
97
+ tags (_type_): _description_
98
+
99
+ Returns:
100
+ _type_: _description_
101
+ """
102
  return {"ner_tags_str": [tags.int2str(idx) for idx in batch["ner_tags"]]}
103
 
104
 
105
+ def encode_dataset(split: Dataset, tokenizer):
106
+ """Encodes a dataset split.
107
+
108
+ Args:
109
+ split (Dataset): A Dataset object.
110
+ tokenizer: A PreTrainedTokenizer object.
111
+
112
+ Returns:
113
+ Dataset: A Dataset object with the encoded inputs.
114
+ """
115
+
116
  tags = split.features["ner_tags"].feature
117
  split = split.map(partial(stringify_ner_tags, tags=tags), batched=True)
118
  remove_columns = split.column_names
 
159
 
160
 
161
  def get_split_df(split_encoded: Dataset, model, tokenizer, collator, tags) -> pd.DataFrame:
162
+ """Turns a Dataset into a pandas dataframe.
163
+
164
+ Args:
165
+ split_encoded (Dataset): _description_
166
+ model (_type_): _description_
167
+ tokenizer (_type_): _description_
168
+ collator (_type_): _description_
169
+ tags (_type_): _description_
170
+
171
+ Returns:
172
+ pd.DataFrame: _description_
173
+ """
174
  split_encoded = split_encoded.map(
175
  partial(
176
  forward_pass_with_label,
src/load.py CHANGED
@@ -39,6 +39,20 @@ def load_context(
39
  split_sample_size: int,
40
  **kw_args,
41
  ) -> Context:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  sentence_encoder, model, tokenizer = _load_models_and_tokenizer(
44
  encoder_model_name=encoder_model_name,
 
39
  split_sample_size: int,
40
  **kw_args,
41
  ) -> Context:
42
+ """Utility method loading (almost) everything we need for the application.
43
+ This exists just because we want to cache the results of this function.
44
+
45
+ Args:
46
+ encoder_model_name (str): Name of the sentence encoder to load.
47
+ model_name (str): Name of the NER model to load.
48
+ ds_name (str): Dataset name or path.
49
+ ds_config_name (str): Dataset config name.
50
+ ds_split_name (str): Dataset split name.
51
+ split_sample_size (int): Number of examples to load from the split.
52
+
53
+ Returns:
54
+ Context: An object containing everything we need for the application.
55
+ """
56
 
57
  sentence_encoder, model, tokenizer = _load_models_and_tokenizer(
58
  encoder_model_name=encoder_model_name,
src/subpages/attention.py CHANGED
@@ -80,7 +80,7 @@ JS_TEMPLATE = """requirejs(['basic', 'ecco'], function(basic, ecco){{
80
 
81
 
82
  @st.cache(allow_output_mutation=True)
83
- def load_ecco_model():
84
  model_config = {
85
  "embedding": "embeddings.word_embeddings",
86
  "type": "mlm",
@@ -115,7 +115,7 @@ class AttentionPage(Page):
115
  "A group of neurons tend to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model."
116
  )
117
 
118
- lm = load_ecco_model()
119
 
120
  col1, _, col2 = st.columns([1.5, 0.5, 4])
121
  with col1:
 
80
 
81
 
82
  @st.cache(allow_output_mutation=True)
83
+ def _load_ecco_model():
84
  model_config = {
85
  "embedding": "embeddings.word_embeddings",
86
  "type": "mlm",
 
115
  "A group of neurons tend to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model."
116
  )
117
 
118
+ lm = _load_ecco_model()
119
 
120
  col1, _, col2 = st.columns([1.5, 0.5, 4])
121
  with col1:
src/subpages/page.py CHANGED
@@ -10,6 +10,8 @@ from transformers import AutoTokenizer # type: ignore
10
 
11
  @dataclass
12
  class Context:
 
 
13
  model: AutoModelForSequenceClassification
14
  tokenizer: AutoTokenizer
15
  sentence_encoder: SentenceTransformer
@@ -27,11 +29,19 @@ class Context:
27
 
28
 
29
  class Page:
 
 
30
  name: str
31
  icon: str
32
 
33
  def get_widget_defaults(self):
 
 
 
 
 
34
  return {}
35
 
36
  def render(self, context):
 
37
  ...
 
10
 
11
  @dataclass
12
  class Context:
13
+ """This object facilitates passing around the applications state between different pages."""
14
+
15
  model: AutoModelForSequenceClassification
16
  tokenizer: AutoTokenizer
17
  sentence_encoder: SentenceTransformer
 
29
 
30
 
31
  class Page:
32
+ """Base class for all pages."""
33
+
34
  name: str
35
  icon: str
36
 
37
  def get_widget_defaults(self):
38
+ """This function holds the default settings for all the page's widgets.
39
+
40
+ Returns:
41
+ dict: A dictionary of widget defaults, where the keys are the widget names and the values are the default.
42
+ """
43
  return {}
44
 
45
  def render(self, context):
46
+ """This function renders the page."""
47
  ...