Ezi Ozoani committed on
Commit b69fb1e
1 Parent(s): 616ba36

test upload

Files changed (44)
  1. .ipynb_checkpoints/app (2)-checkpoint.py +296 -0
  2. Scripts/run.sh +112 -0
  3. app.py +256 -0
  4. data_measurements/__init__.py +0 -0
  5. data_measurements/__pycache__/__init__.cpython-310.pyc +0 -0
  6. data_measurements/__pycache__/__init__.cpython-311.pyc +0 -0
  7. data_measurements/__pycache__/dataset_statistics.cpython-310.pyc +0 -0
  8. data_measurements/__pycache__/dataset_statistics.cpython-311.pyc +0 -0
  9. data_measurements/__pycache__/dataset_utils.cpython-310.pyc +0 -0
  10. data_measurements/__pycache__/dataset_utils.cpython-311.pyc +0 -0
  11. data_measurements/__pycache__/embeddings.cpython-310.pyc +0 -0
  12. data_measurements/__pycache__/embeddings.cpython-311.pyc +0 -0
  13. data_measurements/__pycache__/npmi.cpython-310.pyc +0 -0
  14. data_measurements/__pycache__/npmi.cpython-311.pyc +0 -0
  15. data_measurements/__pycache__/streamlit_utils.cpython-310.pyc +0 -0
  16. data_measurements/__pycache__/streamlit_utils.cpython-311.pyc +0 -0
  17. data_measurements/__pycache__/zipf.cpython-310.pyc +0 -0
  18. data_measurements/__pycache__/zipf.cpython-311.pyc +0 -0
  19. data_measurements/_pycache_/__init__.cpython-311.pyc +0 -0
  20. data_measurements/_pycache_/__init__.cpython-37.pyc +0 -0
  21. data_measurements/_pycache_/dataset_statistics.cpython-311.pyc +0 -0
  22. data_measurements/_pycache_/dataset_statistics.cpython-37.pyc +0 -0
  23. data_measurements/_pycache_/dataset_utils.cpython-311.pyc +0 -0
  24. data_measurements/_pycache_/dataset_utils.cpython-37.pyc +0 -0
  25. data_measurements/_pycache_/embeddings.cpython-311.pyc +0 -0
  26. data_measurements/_pycache_/embeddings.cpython-37.pyc +0 -0
  27. data_measurements/_pycache_/npmi.cpython-311.pyc +0 -0
  28. data_measurements/_pycache_/npmi.cpython-37.pyc +0 -0
  29. data_measurements/_pycache_/streamlit_utils.cpython-311.pyc +0 -0
  30. data_measurements/_pycache_/zipf.cpython-311.pyc +0 -0
  31. data_measurements/_pycache_/zipf.cpython-37.pyc +0 -0
  32. data_measurements/dataset_statistics.py +1223 -0
  33. data_measurements/dataset_utils.py +296 -0
  34. data_measurements/embeddings.py +550 -0
  35. data_measurements/npmi.py +254 -0
  36. data_measurements/streamlit_utils.py +498 -0
  37. data_measurements/zipf.py +247 -0
  38. log_files/app.log +59 -0
  39. log_files/dataset_statistics.log +4 -0
  40. log_files/npmi.log +0 -0
  41. log_files/zipf.log +0 -0
  42. run.sh +110 -0
  43. run_data_measurements.py +296 -0
  44. temp.jsonl +0 -0
.ipynb_checkpoints/app (2)-checkpoint.py ADDED
@@ -0,0 +1,296 @@
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from os import mkdir
+ from os.path import exists, isdir
+ from pathlib import Path
+
+ # #! pip install streamlit
+ import streamlit as st
+
+ # +
+ # #! pip install datasets
+ # #! pip install powerlaw
+ # -
+
+ from data_measurements import dataset_statistics, dataset_utils
+ from data_measurements import streamlit_utils as st_utils
+
+ logs = logging.getLogger(__name__)
+ logs.setLevel(logging.WARNING)
+ logs.propagate = False
+
+ if not logs.handlers:
+
+     Path('./log_files').mkdir(exist_ok=True)
+
+     # Logging info to log file
+     file = logging.FileHandler("./log_files/app.log")
+     fileformat = logging.Formatter("%(asctime)s:%(message)s")
+     file.setLevel(logging.INFO)
+     file.setFormatter(fileformat)
+
+     # Logging debug messages to stream
+     stream = logging.StreamHandler()
+     streamformat = logging.Formatter("[data_measurements_tool] %(message)s")
+     stream.setLevel(logging.WARNING)
+     stream.setFormatter(streamformat)
+
+     logs.addHandler(file)
+     logs.addHandler(stream)
+
+ st.set_page_config(
+     page_title="Demo to showcase dataset metrics",
+     page_icon="https://huggingface.co/front/assets/huggingface_logo.svg",
+     layout="wide",
+     initial_sidebar_state="auto",
+ )
+
+ # colorblind-friendly colors
+ colors = [
+     "#332288",
+     "#117733",
+     "#882255",
+     "#AA4499",
+     "#CC6677",
+     "#44AA99",
+     "#DDCC77",
+     "#88CCEE",
+ ]
+
+ CACHE_DIR = dataset_utils.CACHE_DIR
+ # String names we are using (not coming from the stored dataset).
+ OUR_TEXT_FIELD = dataset_utils.OUR_TEXT_FIELD
+ OUR_LABEL_FIELD = dataset_utils.OUR_LABEL_FIELD
+ TOKENIZED_FIELD = dataset_utils.TOKENIZED_FIELD
+ EMBEDDING_FIELD = dataset_utils.EMBEDDING_FIELD
+ LENGTH_FIELD = dataset_utils.LENGTH_FIELD
+ # TODO: Allow users to specify this.
+ _MIN_VOCAB_COUNT = 10
+ _SHOW_TOP_N_WORDS = 10
+
+
+ @st.cache(
+     hash_funcs={
+         dataset_statistics.DatasetStatisticsCacheClass: lambda dstats: dstats.cache_path
+     },
+     allow_output_mutation=True,
+ )
+ def load_or_prepare(ds_args, show_embeddings, use_cache=False):
+     """
+     Takes the dataset arguments from the GUI and uses them to load a dataset from the Hub or, if
+     a cache for those arguments is available, to load it from the cache.
+     Args:
+         ds_args (dict): the dataset arguments defined via the streamlit app GUI
+         show_embeddings (Bool): whether embeddings should be loaded and displayed for this dataset
+         use_cache (Bool): whether the cache is used by default or not
+     Returns:
+         dstats: the computed dataset statistics (from the dataset_statistics class)
+     """
+     if not isdir(CACHE_DIR):
+         logs.warning("Creating cache")
+         # We need to preprocess everything.
+         # This should eventually all go into a prepare_dataset CLI
+         mkdir(CACHE_DIR)
+     if use_cache:
+         logs.warning("Using cache")
+     dstats = dataset_statistics.DatasetStatisticsCacheClass(CACHE_DIR, **ds_args, use_cache=use_cache)
+     logs.warning("Loading dataset")
+     dstats.load_or_prepare_dataset()
+     logs.warning("Loading labels")
+     dstats.load_or_prepare_labels()
+     logs.warning("Loading text lengths")
+     dstats.load_or_prepare_text_lengths()
+     logs.warning("Loading duplicates")
+     dstats.load_or_prepare_text_duplicates()
+     logs.warning("Loading vocabulary")
+     dstats.load_or_prepare_vocab()
+     logs.warning("Loading general statistics...")
+     dstats.load_or_prepare_general_stats()
+     if show_embeddings:
+         logs.warning("Loading Embeddings")
+         dstats.load_or_prepare_embeddings()
+     logs.warning("Loading nPMI")
+     try:
+         dstats.load_or_prepare_npmi()
+     except:
+         logs.warning("Missing a cache for npmi")
+     logs.warning("Loading Zipf")
+     dstats.load_or_prepare_zipf()
+     return dstats
+
+ @st.cache(
+     hash_funcs={
+         dataset_statistics.DatasetStatisticsCacheClass: lambda dstats: dstats.cache_path
+     },
+     allow_output_mutation=True,
+ )
+ def load_or_prepare_widgets(ds_args, show_embeddings, use_cache=False):
+     """
+     Loader specifically for the widgets used in the app.
+     Args:
+         ds_args:
+         show_embeddings:
+         use_cache:
+
+     Returns:
+
+     """
+
+     if use_cache:
+         logs.warning("Using cache")
+     if True:
+     #try:
+         dstats = dataset_statistics.DatasetStatisticsCacheClass(CACHE_DIR, **ds_args, use_cache=use_cache)
+         # Don't recalculate; we're live
+         dstats.set_deployment(True)
+         # checks whether the cache_dir exists in deployment mode
+         # creates cache_dir if not and if in development mode
+         cache_dir_exists = dstats.check_cache_dir()
+     #except:
+     #    logs.warning("We're screwed")
+     if cache_dir_exists:
+         try:
+             # We need to have the text_dset loaded for further load_or_prepare
+             dstats.load_or_prepare_dataset()
+         except:
+             logs.warning("Missing a cache for load or prepare dataset")
+         try:
+             # Header widget
+             dstats.load_or_prepare_dset_peek()
+         except:
+             logs.warning("Missing a cache for dset peek")
+         try:
+             # General stats widget
+             dstats.load_or_prepare_general_stats()
+         except:
+             logs.warning("Missing a cache for general stats")
+         try:
+             # Labels widget
+             dstats.load_or_prepare_labels()
+         except:
+             logs.warning("Missing a cache for prepare labels")
+         try:
+             # Text lengths widget
+             dstats.load_or_prepare_text_lengths()
+         except:
+             logs.warning("Missing a cache for text lengths")
+         if show_embeddings:
+             try:
+                 # Embeddings widget
+                 dstats.load_or_prepare_embeddings()
+             except:
+                 logs.warning("Missing a cache for embeddings")
+         try:
+             dstats.load_or_prepare_text_duplicates()
+         except:
+             logs.warning("Missing a cache for text duplicates")
+         try:
+             dstats.load_or_prepare_npmi()
+         except:
+             logs.warning("Missing a cache for npmi")
+         try:
+             dstats.load_or_prepare_zipf()
+         except:
+             logs.warning("Missing a cache for zipf")
+     return dstats, cache_dir_exists
+
+ def show_column(dstats, ds_name_to_dict, show_embeddings, column_id):
+     """
+     Function for displaying the elements in the right column of the streamlit app.
+     Args:
+         ds_name_to_dict (dict): the dataset name and options in dictionary form
+         show_embeddings (Bool): whether embeddings should be loaded and displayed for this dataset
+         column_id (str): what column of the dataset the analysis is done on
+     Returns:
+         The function displays the information using the functions defined in the st_utils class.
+     """
+     # Note that at this point we assume we can use cache; default value is True.
+     # start showing stuff
+     title_str = f"### Showing{column_id}: {dstats.dset_name} - {dstats.dset_config} - {dstats.split_name} - {'-'.join(dstats.text_field)}"
+     st.markdown(title_str)
+     logs.info("showing header")
+     st_utils.expander_header(dstats, ds_name_to_dict, column_id)
+     logs.info("showing general stats")
+     st_utils.expander_general_stats(dstats, column_id)
+     st_utils.expander_label_distribution(dstats.fig_labels, column_id)
+     st_utils.expander_text_lengths(dstats, column_id)
+     st_utils.expander_text_duplicates(dstats, column_id)
+     # Uses an interaction; handled a bit differently than other widgets.
+     logs.info("showing npmi widget")
+     st_utils.npmi_widget(dstats.npmi_stats, _MIN_VOCAB_COUNT, column_id)
+     logs.info("showing zipf")
+     st_utils.expander_zipf(dstats.z, dstats.zipf_fig, column_id)
+     if show_embeddings:
+         st_utils.expander_text_embeddings(
+             dstats.text_dset,
+             dstats.fig_tree,
+             dstats.node_list,
+             dstats.embeddings,
+             OUR_TEXT_FIELD,
+             column_id,
+         )
+
+
+ def main():
+     """ Sidebar description and selection """
+     ds_name_to_dict = dataset_utils.get_dataset_info_dicts()
+     st.title("Data Measurements Tool")
+     # Get the sidebar details
+     st_utils.sidebar_header()
+     # Set up naming, configs, and cache path.
+     compare_mode = st.sidebar.checkbox("Comparison mode")
+
+     # When not doing new development, use the cache.
+     use_cache = True
+     show_embeddings = st.sidebar.checkbox("Show text clusters")
+     # List of datasets for which embeddings are hard to compute:
+
+     if compare_mode:
+         logs.warning("Using Comparison Mode")
+         dataset_args_left = st_utils.sidebar_selection(ds_name_to_dict, " A")
+         dataset_args_right = st_utils.sidebar_selection(ds_name_to_dict, " B")
+         left_col, _, right_col = st.columns([10, 1, 10])
+         dstats_left, cache_exists_left = load_or_prepare_widgets(
+             dataset_args_left, show_embeddings, use_cache=use_cache
+         )
+         with left_col:
+             if cache_exists_left:
+                 show_column(dstats_left, ds_name_to_dict, show_embeddings, " A")
+             else:
+                 st.markdown("### Missing pre-computed data measures!")
+                 st.write(dataset_args_left)
+         dstats_right, cache_exists_right = load_or_prepare_widgets(
+             dataset_args_right, show_embeddings, use_cache=use_cache
+         )
+         with right_col:
+             if cache_exists_right:
+                 show_column(dstats_right, ds_name_to_dict, show_embeddings, " B")
+             else:
+                 st.markdown("### Missing pre-computed data measures!")
+                 st.write(dataset_args_right)
+     else:
+         logs.warning("Using Single Dataset Mode")
+         dataset_args = st_utils.sidebar_selection(ds_name_to_dict, "")
+         dstats, cache_exists = load_or_prepare_widgets(dataset_args, show_embeddings, use_cache=use_cache)
+         if cache_exists:
+             show_column(dstats, ds_name_to_dict, show_embeddings, "")
+         else:
+             st.markdown("### Missing pre-computed data measures!")
+             st.write(dataset_args)
+
+
+ if __name__ == "__main__":
+     main()
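
For a quick sanity check of the loaders above outside the Streamlit UI, load_or_prepare_widgets can be called directly from a Python shell in the Space's environment. A minimal sketch only: the ds_args keys mirror DatasetStatisticsCacheClass.__init__ in data_measurements/dataset_statistics.py, and the imdb values (including the label names) are illustrative, not taken from this commit:

    ds_args = {
        "dset_name": "imdb",
        "dset_config": "plain_text",
        "split_name": "train",
        "text_field": ["text"],
        "label_field": ["label"],
        "label_names": ["neg", "pos"],  # illustrative label names
    }
    # Returns the stats object plus whether its cache directory was found.
    dstats, cache_exists = load_or_prepare_widgets(ds_args, show_embeddings=False, use_cache=True)
    print(cache_exists, dstats.cache_path)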
Scripts/run.sh ADDED
@@ -0,0 +1,112 @@
+ #!/usr/bin/env bash
+
+
+ python3 run_data_measurements.py --dataset="hate_speech18" --config="default" --split="train" --label_field="label" --feature="text"
+ python3 run_data_measurements.py --dataset="hate_speech_offensive" --config="default" --split="train" --label_field="label" --feature="tweet"
+
+
+ python3 run_data_measurements.py --dataset="imdb" --config="plain_text" --split="train" --label_field="label" --feature="text"
+ python3 run_data_measurements.py --dataset="imdb" --config="plain_text" --split="unsupervised" --label_field="label" --feature="text"
+
+
+ python3 run_data_measurements.py --dataset="glue" --config="cola" --split="train" --label_field="label" --feature="sentence"
+ python3 run_data_measurements.py --dataset="glue" --config="cola" --split="validation" --label_field="label" --feature="sentence"
+
+ python3 run_data_measurements.py --dataset="glue" --config="mnli" --split="train" --label_field="label" --feature="hypothesis"
+ python3 run_data_measurements.py --dataset="glue" --config="mnli" --split="train" --label_field="label" --feature="premise"
+
+ python3 run_data_measurements.py --dataset="glue" --config="mnli" --split="validation_matched" --label_field="label" --feature="premise"
+ python3 run_data_measurements.py --dataset="glue" --config="mnli" --split="validation_matched" --label_field="label" --feature="hypothesis"
+ python3 run_data_measurements.py --dataset="glue" --config="mnli" --split="validation_mismatched" --label_field="label" --feature="premise"
+ python3 run_data_measurements.py --dataset="glue" --config="mnli" --split="validation_mismatched" --label_field="label" --feature="hypothesis"
+
+
+ python3 run_data_measurements.py --dataset="glue" --config="mrpc" --split="train" --label_field="label" --feature="sentence1"
+ python3 run_data_measurements.py --dataset="glue" --config="mrpc" --split="train" --label_field="label" --feature="sentence2"
+ python3 run_data_measurements.py --dataset="glue" --config="mrpc" --split="validation" --label_field="label" --feature="sentence1"
+ python3 run_data_measurements.py --dataset="glue" --config="mrpc" --split="validation" --label_field="label" --feature="sentence2"
+
+
+ python3 run_data_measurements.py --dataset="glue" --config="rte" --split="train" --label_field="label" --feature="sentence1"
+ python3 run_data_measurements.py --dataset="glue" --config="rte" --split="train" --label_field="label" --feature="sentence2"
+ python3 run_data_measurements.py --dataset="glue" --config="rte" --split="validation" --label_field="label" --feature="sentence1"
+ python3 run_data_measurements.py --dataset="glue" --config="rte" --split="validation" --label_field="label" --feature="sentence2"
+
+
+ python3 run_data_measurements.py --dataset="glue" --config="stsb" --split="train" --label_field="label" --feature="sentence1"
+ python3 run_data_measurements.py --dataset="glue" --config="stsb" --split="train" --label_field="label" --feature="sentence2"
+ python3 run_data_measurements.py --dataset="glue" --config="stsb" --split="validation" --label_field="label" --feature="sentence1"
+ python3 run_data_measurements.py --dataset="glue" --config="stsb" --split="validation" --label_field="label" --feature="sentence2"
+
+ python3 run_data_measurements.py --dataset="glue" --config="wnli" --split="train" --label_field="label" --feature="sentence1"
+ python3 run_data_measurements.py --dataset="glue" --config="wnli" --split="train" --label_field="label" --feature="sentence2"
+ python3 run_data_measurements.py --dataset="glue" --config="wnli" --split="validation" --label_field="label" --feature="sentence1"
+ python3 run_data_measurements.py --dataset="glue" --config="wnli" --split="validation" --label_field="label" --feature="sentence2"
+
+ python3 run_data_measurements.py --dataset="glue" --config="sst2" --split="train" --label_field="label" --feature="sentence"
+ python3 run_data_measurements.py --dataset="glue" --config="sst2" --split="validation" --label_field="label" --feature="sentence"
+
+
+ python3 run_data_measurements.py --dataset="glue" --config="qnli" --split="train" --label_field="label" --feature="question"
+ python3 run_data_measurements.py --dataset="glue" --config="qnli" --split="train" --label_field="label" --feature="sentence"
+ python3 run_data_measurements.py --dataset="glue" --config="qnli" --split="validation" --label_field="label" --feature="question"
+ python3 run_data_measurements.py --dataset="glue" --config="qnli" --split="validation" --label_field="label" --feature="sentence"
+
+
+ python3 run_data_measurements.py --dataset="glue" --config="qqp" --split="train" --label_field="label" --feature="question1"
+ python3 run_data_measurements.py --dataset="glue" --config="qqp" --split="train" --label_field="label" --feature="question2"
+ python3 run_data_measurements.py --dataset="glue" --config="qqp" --split="validation" --label_field="label" --feature="question1"
+ python3 run_data_measurements.py --dataset="glue" --config="qqp" --split="validation" --label_field="label" --feature="question2"
+
+ python3 run_data_measurements.py --dataset="glue" --config="mnli_matched" --split="validation" --label_field="label" --feature="hypothesis"
+ python3 run_data_measurements.py --dataset="glue" --config="mnli_matched" --split="validation" --label_field="label" --feature="premise"
+ python3 run_data_measurements.py --dataset="glue" --config="mnli_mismatched" --split="validation" --label_field="label" --feature="hypothesis"
+ python3 run_data_measurements.py --dataset="glue" --config="mnli_mismatched" --split="validation" --label_field="label" --feature="premise"
+
+
+ python3 run_data_measurements.py --dataset="wikitext" --config="wikitext-103-v1" --split="train" --feature="text"
+ python3 run_data_measurements.py --dataset="wikitext" --config="wikitext-103-raw-v1" --split="train" --feature="text"
+ python3 run_data_measurements.py --dataset="wikitext" --config="wikitext-2-v1" --split="train" --feature="text"
+ python3 run_data_measurements.py --dataset="wikitext" --config="wikitext-2-raw-v1" --split="train" --feature="text"
+ python3 run_data_measurements.py --dataset="wikitext" --config="wikitext-103-v1" --split="validation" --feature="text"
+ python3 run_data_measurements.py --dataset="wikitext" --config="wikitext-103-raw-v1" --split="validation" --feature="text"
+ python3 run_data_measurements.py --dataset="wikitext" --config="wikitext-2-v1" --split="validation" --feature="text"
+ python3 run_data_measurements.py --dataset="wikitext" --config="wikitext-2-raw-v1" --split="validation" --feature="text"
+
+
+ # Superglue wsc? wic? rte? record? multirc?
+
+ python3 run_data_measurements.py --dataset="super_glue" --config="boolq" --split="train" --label_field="label" --feature="question"
+ python3 run_data_measurements.py --dataset="super_glue" --config="boolq" --split="validation" --label_field="label" --feature="question"
+ python3 run_data_measurements.py --dataset="super_glue" --config="boolq" --split="train" --label_field="label" --feature="passage"
+ python3 run_data_measurements.py --dataset="super_glue" --config="boolq" --split="validation" --label_field="label" --feature="passage"
+
+ python3 run_data_measurements.py --dataset="super_glue" --config="cb" --split="train" --label_field="label" --feature="premise"
+ python3 run_data_measurements.py --dataset="super_glue" --config="cb" --split="validation" --label_field="label" --feature="premise"
+ python3 run_data_measurements.py --dataset="super_glue" --config="cb" --split="train" --label_field="label" --feature="hypothesis"
+ python3 run_data_measurements.py --dataset="super_glue" --config="cb" --split="validation" --label_field="label" --feature="hypothesis"
+
+
+ python3 run_data_measurements.py --dataset="super_glue" --config="copa" --split="train" --label_field="label" --feature="premise"
+ python3 run_data_measurements.py --dataset="super_glue" --config="copa" --split="validation" --label_field="label" --feature="premise"
+ python3 run_data_measurements.py --dataset="super_glue" --config="copa" --split="train" --label_field="label" --feature="choice1"
+ python3 run_data_measurements.py --dataset="super_glue" --config="copa" --split="validation" --label_field="label" --feature="choice1"
+ python3 run_data_measurements.py --dataset="super_glue" --config="copa" --split="train" --label_field="label" --feature="choice2"
+ python3 run_data_measurements.py --dataset="super_glue" --config="copa" --split="validation" --label_field="label" --feature="choice2"
+ python3 run_data_measurements.py --dataset="super_glue" --config="copa" --split="train" --label_field="label" --feature="question"
+ python3 run_data_measurements.py --dataset="super_glue" --config="copa" --split="validation" --label_field="label" --feature="question"
+
+ python3 run_data_measurements.py --dataset="squad" --config="plain_text" --split="train" --feature="context"
+ python3 run_data_measurements.py --dataset="squad" --config="plain_text" --split="train" --feature="question"
+ python3 run_data_measurements.py --dataset="squad" --config="plain_text" --split="train" --feature="title"
+ python3 run_data_measurements.py --dataset="squad" --config="plain_text" --split="validation" --feature="context"
+ python3 run_data_measurements.py --dataset="squad" --config="plain_text" --split="validation" --feature="question"
+ python3 run_data_measurements.py --dataset="squad" --config="plain_text" --split="validation" --feature="title"
+
+
+ python3 run_data_measurements.py --dataset="squad_v2" --config="squad_v2" --split="train" --feature="context"
+ python3 run_data_measurements.py --dataset="squad_v2" --config="squad_v2" --split="train" --feature="question"
+ python3 run_data_measurements.py --dataset="squad_v2" --config="squad_v2" --split="train" --feature="title"
+ python3 run_data_measurements.py --dataset="squad_v2" --config="squad_v2" --split="validation" --feature="context"
+ python3 run_data_measurements.py --dataset="squad_v2" --config="squad_v2" --split="validation" --feature="question"
+ python3 run_data_measurements.py --dataset="squad_v2" --config="squad_v2" --split="validation" --feature="title"
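
The script spells out every (dataset, config, split, feature) combination by hand; the same sweep can be generated programmatically. A sketch in Python, assuming only the CLI flags visible above (note that the squad, squad_v2, and wikitext rows omit --label_field); the job list here is an illustrative subset:

    import subprocess

    # Illustrative subset of the jobs listed in run.sh above.
    jobs = [
        ("glue", "cola", "train", "label", "sentence"),
        ("glue", "cola", "validation", "label", "sentence"),
        ("squad", "plain_text", "train", None, "context"),
    ]
    for dataset, config, split, label_field, feature in jobs:
        cmd = ["python3", "run_data_measurements.py",
               f"--dataset={dataset}", f"--config={config}",
               f"--split={split}", f"--feature={feature}"]
        if label_field is not None:  # unlabeled rows skip --label_field
            cmd.append(f"--label_field={label_field}")
        subprocess.run(cmd, check=True)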
app.py ADDED
@@ -0,0 +1,256 @@
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from os import mkdir
+ from os.path import exists, isdir
+ from pathlib import Path
+
+ # #! pip install streamlit
+ import streamlit as st
+
+ # +
+ # #! pip install datasets
+ # #! pip install powerlaw
+ # -
+
+ from data_measurements import dataset_statistics, dataset_utils
+ from data_measurements import streamlit_utils as st_utils
+
+ logs = logging.getLogger(__name__)
+ logs.setLevel(logging.WARNING)
+ logs.propagate = False
+
+ if not logs.handlers:
+
+     Path('./log_files').mkdir(exist_ok=True)
+
+     # Logging info to log file
+     file = logging.FileHandler("./log_files/app.log")
+     fileformat = logging.Formatter("%(asctime)s:%(message)s")
+     file.setLevel(logging.INFO)
+     file.setFormatter(fileformat)
+
+     # Logging debug messages to stream
+     stream = logging.StreamHandler()
+     streamformat = logging.Formatter("[data_measurements_tool] %(message)s")
+     stream.setLevel(logging.WARNING)
+     stream.setFormatter(streamformat)
+
+     logs.addHandler(file)
+     logs.addHandler(stream)
+
+ st.set_page_config(
+     page_title="Demo to showcase dataset metrics",
+     page_icon="https://huggingface.co/front/assets/huggingface_logo.svg",
+     layout="wide",
+     initial_sidebar_state="auto",
+ )
+
+ # colorblind-friendly colors
+ colors = [
+     "#332288",
+     "#117733",
+     "#882255",
+     "#AA4499",
+     "#CC6677",
+     "#44AA99",
+     "#DDCC77",
+     "#88CCEE",
+ ]
+
+ CACHE_DIR = dataset_utils.CACHE_DIR
+ # String names we are using (not coming from the stored dataset).
+ OUR_TEXT_FIELD = dataset_utils.OUR_TEXT_FIELD
+ OUR_LABEL_FIELD = dataset_utils.OUR_LABEL_FIELD
+ TOKENIZED_FIELD = dataset_utils.TOKENIZED_FIELD
+ EMBEDDING_FIELD = dataset_utils.EMBEDDING_FIELD
+ LENGTH_FIELD = dataset_utils.LENGTH_FIELD
+ # TODO: Allow users to specify this.
+ _MIN_VOCAB_COUNT = 10
+ _SHOW_TOP_N_WORDS = 10
+
+
+ @st.cache(
+     hash_funcs={
+         dataset_statistics.DatasetStatisticsCacheClass: lambda dstats: dstats.cache_path
+     },
+     allow_output_mutation=True,
+ )
+ def load_or_prepare(ds_args, show_embeddings, use_cache=False):
+     """
+     Takes the dataset arguments from the GUI and uses them to load a dataset from the Hub or, if
+     a cache for those arguments is available, to load it from the cache.
+     Args:
+         ds_args (dict): the dataset arguments defined via the streamlit app GUI
+         show_embeddings (Bool): whether embeddings should be loaded and displayed for this dataset
+         use_cache (Bool): whether the cache is used by default or not
+     Returns:
+         dstats: the computed dataset statistics (from the dataset_statistics class)
+     """
+     if not isdir(CACHE_DIR):
+         logs.warning("Creating cache")
+         # We need to preprocess everything.
+         # This should eventually all go into a prepare_dataset CLI
+         mkdir(CACHE_DIR)
+     if use_cache:
+         logs.warning("Using cache")
+     dstats = dataset_statistics.DatasetStatisticsCacheClass(CACHE_DIR, **ds_args, use_cache=use_cache)
+     logs.warning("Loading dataset")
+     dstats.load_or_prepare_dataset()
+     if show_embeddings:
+         logs.warning("Loading Embeddings")
+         dstats.load_or_prepare_embeddings()
+     logs.warning("Loading nPMI")
+     try:
+         dstats.load_or_prepare_npmi()
+     except:
+         logs.warning("Missing a cache for npmi")
+     return dstats
+
+ @st.cache(
+     hash_funcs={
+         dataset_statistics.DatasetStatisticsCacheClass: lambda dstats: dstats.cache_path
+     },
+     allow_output_mutation=True,
+ )
+ def load_or_prepare_widgets(ds_args, show_embeddings, use_cache=False):
+     """
+     Loader specifically for the widgets used in the app.
+     Args:
+         ds_args:
+         show_embeddings:
+         use_cache:
+
+     Returns:
+
+     """
+
+     if use_cache:
+         logs.warning("Using cache")
+     if True:
+     #try:
+         dstats = dataset_statistics.DatasetStatisticsCacheClass(CACHE_DIR, **ds_args, use_cache=use_cache)
+         # Don't recalculate; we're live
+         dstats.set_deployment(True)
+         # checks whether the cache_dir exists in deployment mode
+         # creates cache_dir if not and if in development mode
+         cache_dir_exists = dstats.check_cache_dir()
+     #except:
+     #    logs.warning("We're screwed")
+     if cache_dir_exists:
+         try:
+             # We need to have the text_dset loaded for further load_or_prepare
+             dstats.load_or_prepare_dataset()
+         except:
+             logs.warning("Missing a cache for load or prepare dataset")
+         try:
+             # Header widget
+             dstats.load_or_prepare_dset_peek()
+         except:
+             logs.warning("Missing a cache for dset peek")
+         if show_embeddings:
+             try:
+                 # Embeddings widget
+                 dstats.load_or_prepare_embeddings()
+             except:
+                 logs.warning("Missing a cache for embeddings")
+         try:
+             dstats.load_or_prepare_text_duplicates()
+         except:
+             logs.warning("Missing a cache for text duplicates")
+         try:
+             dstats.load_or_prepare_npmi()
+         except:
+             logs.warning("Missing a cache for npmi")
+     return dstats, cache_dir_exists
+
+ def show_column(dstats, ds_name_to_dict, show_embeddings, column_id):
+     """
+     Function for displaying the elements in the right column of the streamlit app.
+     Args:
+         ds_name_to_dict (dict): the dataset name and options in dictionary form
+         show_embeddings (Bool): whether embeddings should be loaded and displayed for this dataset
+         column_id (str): what column of the dataset the analysis is done on
+     Returns:
+         The function displays the information using the functions defined in the st_utils class.
+     """
+     # Note that at this point we assume we can use cache; default value is True.
+     # start showing stuff
+     title_str = f"### Showing{column_id}: {dstats.dset_name} - {dstats.dset_config} - {dstats.split_name} - {'-'.join(dstats.text_field)}"
+     st.markdown(title_str)
+     # Uses an interaction; handled a bit differently than other widgets.
+     logs.info("showing npmi widget")
+     st_utils.npmi_widget(dstats.npmi_stats, _MIN_VOCAB_COUNT, column_id)
+     if show_embeddings:
+         st_utils.expander_text_embeddings(
+             dstats.text_dset,
+             dstats.fig_tree,
+             dstats.node_list,
+             dstats.embeddings,
+             OUR_TEXT_FIELD,
+             column_id,
+         )
+
+
+ def main():
+     """ Sidebar description and selection """
+     ds_name_to_dict = dataset_utils.get_dataset_info_dicts()
+     st.title("Data Measurements Tool")
+     # Get the sidebar details
+     st_utils.sidebar_header()
+     # Set up naming, configs, and cache path.
+     compare_mode = st.sidebar.checkbox("Comparison mode")
+
+     # When not doing new development, use the cache.
+     use_cache = True
+     show_embeddings = st.sidebar.checkbox("Show text clusters")
+     # List of datasets for which embeddings are hard to compute:
+
+     if compare_mode:
+         logs.warning("Using Comparison Mode")
+         dataset_args_left = st_utils.sidebar_selection(ds_name_to_dict, " A")
+         dataset_args_right = st_utils.sidebar_selection(ds_name_to_dict, " B")
+         left_col, _, right_col = st.columns([10, 1, 10])
+         dstats_left, cache_exists_left = load_or_prepare_widgets(
+             dataset_args_left, show_embeddings, use_cache=use_cache
+         )
+         with left_col:
+             if cache_exists_left:
+                 show_column(dstats_left, ds_name_to_dict, show_embeddings, " A")
+             else:
+                 st.markdown("### Missing pre-computed data measures!")
+                 st.write(dataset_args_left)
+         dstats_right, cache_exists_right = load_or_prepare_widgets(
+             dataset_args_right, show_embeddings, use_cache=use_cache
+         )
+         with right_col:
+             if cache_exists_right:
+                 show_column(dstats_right, ds_name_to_dict, show_embeddings, " B")
+             else:
+                 st.markdown("### Missing pre-computed data measures!")
+                 st.write(dataset_args_right)
+     else:
+         logs.warning("Using Single Dataset Mode")
+         dataset_args = st_utils.sidebar_selection(ds_name_to_dict, "")
+         dstats, cache_exists = load_or_prepare_widgets(dataset_args, show_embeddings, use_cache=use_cache)
+         if cache_exists:
+             show_column(dstats, ds_name_to_dict, show_embeddings, "")
+         else:
+             st.markdown("### Missing pre-computed data measures!")
+             st.write(dataset_args)
+
+
+ if __name__ == "__main__":
+     main()
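
Both app files wrap DatasetStatisticsCacheClass in @st.cache via hash_funcs, so legacy Streamlit hashes the otherwise-unhashable stats object by its cache_path string and re-runs the loader only when that path changes. A stripped-down sketch of the pattern; the Stats class here is a hypothetical stand-in, not part of this commit:

    import streamlit as st

    class Stats:
        # stand-in for DatasetStatisticsCacheClass; only cache_path matters
        def __init__(self, cache_path):
            self.cache_path = cache_path

    @st.cache(hash_funcs={Stats: lambda s: s.cache_path}, allow_output_mutation=True)
    def load(stats):
        # cached per cache_path; allow_output_mutation skips output hashing
        return {"loaded_from": stats.cache_path}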
data_measurements/__init__.py ADDED
File without changes
data_measurements/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (171 Bytes).
data_measurements/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (170 Bytes).
data_measurements/__pycache__/dataset_statistics.cpython-310.pyc ADDED
Binary file (32.6 kB).
data_measurements/__pycache__/dataset_statistics.cpython-311.pyc ADDED
Binary file (62.3 kB).
data_measurements/__pycache__/dataset_utils.cpython-310.pyc ADDED
Binary file (7.1 kB).
data_measurements/__pycache__/dataset_utils.cpython-311.pyc ADDED
Binary file (11.5 kB).
data_measurements/__pycache__/embeddings.cpython-310.pyc ADDED
Binary file (16.5 kB).
data_measurements/__pycache__/embeddings.cpython-311.pyc ADDED
Binary file (28.9 kB).
data_measurements/__pycache__/npmi.cpython-310.pyc ADDED
Binary file (6.3 kB).
data_measurements/__pycache__/npmi.cpython-311.pyc ADDED
Binary file (12.6 kB).
data_measurements/__pycache__/streamlit_utils.cpython-310.pyc ADDED
Binary file (16.2 kB).
data_measurements/__pycache__/streamlit_utils.cpython-311.pyc ADDED
Binary file (27.8 kB).
data_measurements/__pycache__/zipf.cpython-310.pyc ADDED
Binary file (7.26 kB).
data_measurements/__pycache__/zipf.cpython-311.pyc ADDED
Binary file (12.7 kB).
data_measurements/_pycache_/__init__.cpython-311.pyc ADDED
Binary file (237 Bytes).
data_measurements/_pycache_/__init__.cpython-37.pyc ADDED
Binary file (166 Bytes).
data_measurements/_pycache_/dataset_statistics.cpython-311.pyc ADDED
Binary file (62.4 kB).
data_measurements/_pycache_/dataset_statistics.cpython-37.pyc ADDED
Binary file (31.6 kB).
data_measurements/_pycache_/dataset_utils.cpython-311.pyc ADDED
Binary file (11.6 kB).
data_measurements/_pycache_/dataset_utils.cpython-37.pyc ADDED
Binary file (7.04 kB).
data_measurements/_pycache_/embeddings.cpython-311.pyc ADDED
Binary file (28.9 kB).
data_measurements/_pycache_/embeddings.cpython-37.pyc ADDED
Binary file (16.5 kB).
data_measurements/_pycache_/npmi.cpython-311.pyc ADDED
Binary file (12.7 kB).
data_measurements/_pycache_/npmi.cpython-37.pyc ADDED
Binary file (6.23 kB).
data_measurements/_pycache_/streamlit_utils.cpython-311.pyc ADDED
Binary file (27.8 kB).
data_measurements/_pycache_/zipf.cpython-311.pyc ADDED
Binary file (12.8 kB).
data_measurements/_pycache_/zipf.cpython-37.pyc ADDED
Binary file (7.24 kB).
data_measurements/dataset_statistics.py ADDED
@@ -0,0 +1,1223 @@
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import logging
+ import statistics
+ from os import mkdir
+ from os.path import exists, isdir
+ from os.path import join as pjoin
+
+ import matplotlib.pyplot as plt
+ import matplotlib.image as mpimg
+ import nltk
+ import numpy as np
+ import pandas as pd
+ import plotly
+ import plotly.express as px
+ import plotly.figure_factory as ff
+ import plotly.graph_objects as go
+ import pyarrow.feather as feather
+ import seaborn as sns
+ import torch
+ from datasets import load_from_disk
+ from nltk.corpus import stopwords
+ from sklearn.feature_extraction.text import CountVectorizer
+
+ from .dataset_utils import (CNT, DEDUP_TOT, EMBEDDING_FIELD, LENGTH_FIELD,
+                             OUR_LABEL_FIELD, OUR_TEXT_FIELD, PROP,
+                             TEXT_NAN_CNT, TOKENIZED_FIELD, TOT_OPEN_WORDS,
+                             TOT_WORDS, TXT_LEN, VOCAB, WORD, extract_field,
+                             load_truncated_dataset)
+ from .embeddings import Embeddings
+ from .npmi import nPMI
+ from .zipf import Zipf
+
+ pd.options.display.float_format = "{:,.3f}".format
+
+ logs = logging.getLogger(__name__)
+ logs.setLevel(logging.WARNING)
+ logs.propagate = False
+
+ if not logs.handlers:
+
+     # Logging info to log file
+     file = logging.FileHandler("./log_files/dataset_statistics.log")
+     fileformat = logging.Formatter("%(asctime)s:%(message)s")
+     file.setLevel(logging.INFO)
+     file.setFormatter(fileformat)
+
+     # Logging debug messages to stream
+     stream = logging.StreamHandler()
+     streamformat = logging.Formatter("[data_measurements_tool] %(message)s")
+     stream.setLevel(logging.WARNING)
+     stream.setFormatter(streamformat)
+
+     logs.addHandler(file)
+     logs.addHandler(stream)
+
+
+ # TODO: Read this in depending on chosen language / expand beyond english
+ nltk.download("stopwords")
+ _CLOSED_CLASS = (
+     stopwords.words("english")
+     + [
+         "t",
+         "n",
+         "ll",
+         "d",
+         "wasn",
+         "weren",
+         "won",
+         "aren",
+         "wouldn",
+         "shouldn",
+         "didn",
+         "don",
+         "hasn",
+         "ain",
+         "couldn",
+         "doesn",
+         "hadn",
+         "haven",
+         "isn",
+         "mightn",
+         "mustn",
+         "needn",
+         "shan",
+         "would",
+         "could",
+         "dont",
+         "u",
+     ]
+     + [str(i) for i in range(0, 21)]
+ )
+ _IDENTITY_TERMS = [
+     "man",
+     "woman",
+     "non-binary",
+     "gay",
+     "lesbian",
+     "queer",
+     "trans",
+     "straight",
+     "cis",
+     "she",
+     "her",
+     "hers",
+     "he",
+     "him",
+     "his",
+     "they",
+     "them",
+     "their",
+     "theirs",
+     "himself",
+     "herself",
+ ]
+ # treating inf values as NaN as well
+ pd.set_option("use_inf_as_na", True)
+
+ _MIN_VOCAB_COUNT = 10
+ _TREE_DEPTH = 12
+ _TREE_MIN_NODES = 250
+ # as long as we're using sklearn - already pushing the resources
+ _MAX_CLUSTER_EXAMPLES = 5000
+ _NUM_VOCAB_BATCHES = 2000
+ _TOP_N = 100
+ _CVEC = CountVectorizer(token_pattern="(?u)\\b\\w+\\b", lowercase=True)
+
+
+ class DatasetStatisticsCacheClass:
+     def __init__(
+         self,
+         cache_dir,
+         dset_name,
+         dset_config,
+         split_name,
+         text_field,
+         label_field,
+         label_names,
+         calculation=None,
+         use_cache=False,
+     ):
+         # This is only used for standalone runs for each kind of measurement.
+         self.calculation = calculation
+         self.our_text_field = OUR_TEXT_FIELD
+         self.our_length_field = LENGTH_FIELD
+         self.our_label_field = OUR_LABEL_FIELD
+         self.our_tokenized_field = TOKENIZED_FIELD
+         self.our_embedding_field = EMBEDDING_FIELD
+         self.cache_dir = cache_dir
+         # Use stored data if there; otherwise calculate afresh
+         self.use_cache = use_cache
+         ### What are we analyzing?
+         # name of the Hugging Face dataset
+         self.dset_name = dset_name
+         # name of the dataset config
+         self.dset_config = dset_config
+         # name of the split to analyze
+         self.split_name = split_name
+         # TODO: Should this be "feature" ?
+         # which text fields are we analysing?
+         self.text_field = text_field
+         # which label fields are we analysing?
+         self.label_field = label_field
+         # what are the names of the classes?
+         self.label_names = label_names
+         ## Hugging Face dataset objects
+         self.dset = None  # original dataset
+         # HF dataset with all of the self.text_field instances in self.dset
+         self.text_dset = None
+         self.dset_peek = None
+         # HF dataset with text embeddings in the same order as self.text_dset
+         self.embeddings_dset = None
+         # HF dataset with all of the self.label_field instances in self.dset
+         self.label_dset = None
+         ## Data frames
+         # Tokenized text
+         self.tokenized_df = None
+         # save sentence length histogram in the class so it doesn't get re-computed
+         self.length_df = None
+         self.fig_tok_length = None
+         # Data Frame version of self.label_dset
+         self.label_df = None
+         # save label pie chart in the class so it doesn't get re-computed
+         self.fig_labels = None
+         # Vocabulary with word counts in the dataset
+         self.vocab_counts_df = None
+         # Vocabulary filtered to remove stopwords
+         self.vocab_counts_filtered_df = None
+         self.sorted_top_vocab_df = None
+         ## General statistics and duplicates
+         self.total_words = 0
+         self.total_open_words = 0
+         # Number of NaN values (NOT empty strings)
+         self.text_nan_count = 0
+         # Number of text items that appear more than once in the dataset
+         self.dedup_total = 0
+         # Duplicated text items along with their number of occurrences ("count")
+         self.dup_counts_df = None
+         self.avg_length = None
+         self.std_length = None
+         self.general_stats_dict = None
+         self.num_uniq_lengths = 0
+         # clustering text by embeddings
+         # the hierarchical clustering tree is represented as a list of nodes,
+         # the first is the root
+         self.node_list = []
+         # save tree figure in the class so it doesn't get re-computed
+         self.fig_tree = None
+         # keep Embeddings object around to explore clusters
+         self.embeddings = None
+         # nPMI
+         # Holds a nPMIStatisticsCacheClass object
+         self.npmi_stats = None
+         # TODO: Have lowercase be an option for a user to set.
+         self.to_lowercase = True
+         # The minimum amount of times a word should occur to be included in
+         # word-count-based calculations (currently just relevant to nPMI)
+         self.min_vocab_count = _MIN_VOCAB_COUNT
+         # zipf
+         self.z = None
+         self.zipf_fig = None
+         self.cvec = _CVEC
+         # File definitions
+         # path to the directory used for caching
+         if not isinstance(text_field, str):
+             text_field = "-".join(text_field)
+         # if isinstance(label_field, str):
+         #     label_field = label_field
+         # else:
+         #     label_field = "-".join(label_field)
+         self.cache_path = pjoin(
+             self.cache_dir,
+             f"{dset_name}_{dset_config}_{split_name}_{text_field}",  # {label_field},
+         )
+
+         # Cache files not needed for UI
+         self.dset_fid = pjoin(self.cache_path, "base_dset")
+         self.tokenized_df_fid = pjoin(self.cache_path, "tokenized_df.feather")
+         self.label_dset_fid = pjoin(self.cache_path, "label_dset")
+
+         # Needed for UI -- embeddings
+         self.text_dset_fid = pjoin(self.cache_path, "text_dset")
+         # Needed for UI
+         self.dset_peek_json_fid = pjoin(self.cache_path, "dset_peek.json")
+
+         ## Label cache files.
+         # Needed for UI
+         self.fig_labels_json_fid = pjoin(self.cache_path, "fig_labels.json")
+
+         ## Length cache files
+         # Needed for UI
+         self.length_df_fid = pjoin(self.cache_path, "length_df.feather")
+         # Needed for UI
+         self.length_stats_json_fid = pjoin(self.cache_path, "length_stats.json")
+         self.vocab_counts_df_fid = pjoin(self.cache_path, "vocab_counts.feather")
+         # Needed for UI
+         self.dup_counts_df_fid = pjoin(self.cache_path, "dup_counts_df.feather")
+         # Needed for UI
+         self.fig_tok_length_fid = pjoin(self.cache_path, "fig_tok_length.png")
+
+         ## General text stats
+         # Needed for UI
+         self.general_stats_json_fid = pjoin(self.cache_path, "general_stats_dict.json")
+         # Needed for UI
+         self.sorted_top_vocab_df_fid = pjoin(
+             self.cache_path, "sorted_top_vocab.feather"
+         )
+         ## Zipf cache files
+         # Needed for UI
+         self.zipf_fid = pjoin(self.cache_path, "zipf_basic_stats.json")
+         # Needed for UI
+         self.zipf_fig_fid = pjoin(self.cache_path, "zipf_fig.json")
+
+         ## Embeddings cache files
+         # Needed for UI
+         self.node_list_fid = pjoin(self.cache_path, "node_list.th")
+         # Needed for UI
+         self.fig_tree_json_fid = pjoin(self.cache_path, "fig_tree.json")
+
+         self.live = False
+
+     def set_deployment(self, live=True):
+         """
+         Function that we can hit when we deploy, so that cache files are not
+         written out/recalculated, but instead that part of the UI can be punted.
+         """
+         self.live = live
+
+     def check_cache_dir(self):
+         """
+         First function to call to create the cache directory.
+         If in deployment mode and cache directory does not already exist,
+         return False.
+         """
+         if self.live:
+             return isdir(self.cache_path)
+         else:
+             if not isdir(self.cache_path):
+                 logs.warning("Creating cache directory %s." % self.cache_path)
+                 mkdir(self.cache_path)
+             return isdir(self.cache_path)
+
+
+     def get_base_dataset(self):
+         """Gets a pointer to the truncated base dataset object."""
+         if not self.dset:
+             self.dset = load_truncated_dataset(
+                 self.dset_name,
+                 self.dset_config,
+                 self.split_name,
+                 cache_name=self.dset_fid,
+                 use_cache=True,
+                 use_streaming=True,
+             )
+
+     def load_or_prepare_general_stats(self, save=True):
+         """
+         Content for expander_general_stats widget.
+         Provides statistics for total words, total open words,
+         the sorted top vocab, the NaN count, and the duplicate count.
+         Args:
+
+         Returns:
+
+         """
+         # General statistics
+         if (
+             self.use_cache
+             and exists(self.general_stats_json_fid)
+             and exists(self.dup_counts_df_fid)
+             and exists(self.sorted_top_vocab_df_fid)
+         ):
+             logs.info("Loading cached general stats")
+             self.load_general_stats()
+         else:
+             if not self.live:
+                 logs.info("Preparing general stats")
+                 self.prepare_general_stats()
+                 if save:
+                     write_df(self.sorted_top_vocab_df, self.sorted_top_vocab_df_fid)
+                     write_df(self.dup_counts_df, self.dup_counts_df_fid)
+                     write_json(self.general_stats_dict, self.general_stats_json_fid)
+
+     def load_or_prepare_text_lengths(self, save=True):
+         """
+         The text length widget relies on this function, which provides
+         a figure of the text lengths, some text length statistics, and
+         a text length dataframe to peruse.
+         Args:
+             save:
+         Returns:
+
+         """
+         # Text length figure
+         if self.use_cache and exists(self.fig_tok_length_fid):
+             self.fig_tok_length_png = mpimg.imread(self.fig_tok_length_fid)
+         else:
+             if not self.live:
+                 self.prepare_fig_text_lengths()
+                 if save:
+                     self.fig_tok_length.savefig(self.fig_tok_length_fid)
+         # Text length dataframe
+         if self.use_cache and exists(self.length_df_fid):
+             self.length_df = feather.read_feather(self.length_df_fid)
+         else:
+             if not self.live:
+                 self.prepare_length_df()
+                 if save:
+                     write_df(self.length_df, self.length_df_fid)
+
+         # Text length stats.
+         if self.use_cache and exists(self.length_stats_json_fid):
+             with open(self.length_stats_json_fid, "r") as f:
+                 self.length_stats_dict = json.load(f)
+             self.avg_length = self.length_stats_dict["avg length"]
+             self.std_length = self.length_stats_dict["std length"]
+             self.num_uniq_lengths = self.length_stats_dict["num lengths"]
+         else:
+             if not self.live:
+                 self.prepare_text_length_stats()
+                 if save:
+                     write_json(self.length_stats_dict, self.length_stats_json_fid)
+
+     def prepare_length_df(self):
+         if not self.live:
+             if self.tokenized_df is None:
+                 self.tokenized_df = self.do_tokenization()
+             self.tokenized_df[LENGTH_FIELD] = self.tokenized_df[TOKENIZED_FIELD].apply(
+                 len
+             )
+             self.length_df = self.tokenized_df[
+                 [LENGTH_FIELD, OUR_TEXT_FIELD]
+             ].sort_values(by=[LENGTH_FIELD], ascending=True)
+
+     def prepare_text_length_stats(self):
+         if not self.live:
+             if (
+                 self.tokenized_df is None
+                 or LENGTH_FIELD not in self.tokenized_df.columns
+                 or self.length_df is None
+             ):
+                 self.prepare_length_df()
+             avg_length = sum(self.tokenized_df[LENGTH_FIELD]) / len(
+                 self.tokenized_df[LENGTH_FIELD]
+             )
+             self.avg_length = round(avg_length, 1)
+             std_length = statistics.stdev(self.tokenized_df[LENGTH_FIELD])
+             self.std_length = round(std_length, 1)
+             self.num_uniq_lengths = len(self.length_df["length"].unique())
+             self.length_stats_dict = {
+                 "avg length": self.avg_length,
+                 "std length": self.std_length,
+                 "num lengths": self.num_uniq_lengths,
+             }
+
+     def prepare_fig_text_lengths(self):
+         if not self.live:
+             if (
+                 self.tokenized_df is None
+                 or LENGTH_FIELD not in self.tokenized_df.columns
+             ):
+                 self.prepare_length_df()
+             self.fig_tok_length = make_fig_lengths(self.tokenized_df, LENGTH_FIELD)
+
+     def load_or_prepare_embeddings(self):
+         self.embeddings = Embeddings(self, use_cache=self.use_cache)
+         self.embeddings.make_hierarchical_clustering()
+         self.node_list = self.embeddings.node_list
+         self.fig_tree = self.embeddings.fig_tree
+
+     # get vocab with word counts
+     def load_or_prepare_vocab(self, save=True):
+         """
+         Calculates the vocabulary count from the tokenized text.
+         The resulting dataframes may be used in nPMI calculations, zipf, etc.
+         :param
+         :return:
+         """
+         if self.use_cache and exists(self.vocab_counts_df_fid):
+             logs.info("Reading vocab from cache")
+             self.load_vocab()
+             self.vocab_counts_filtered_df = filter_vocab(self.vocab_counts_df)
+         else:
+             logs.info("Calculating vocab afresh")
+             if self.tokenized_df is None:
+                 self.tokenized_df = self.do_tokenization()
+                 if save:
+                     logs.info("Writing out.")
+                     write_df(self.tokenized_df, self.tokenized_df_fid)
+             word_count_df = count_vocab_frequencies(self.tokenized_df)
+             logs.info("Making dfs with proportion.")
+             self.vocab_counts_df = calc_p_word(word_count_df)
+             self.vocab_counts_filtered_df = filter_vocab(self.vocab_counts_df)
+             if save:
+                 logs.info("Writing out.")
+                 write_df(self.vocab_counts_df, self.vocab_counts_df_fid)
+         logs.info("unfiltered vocab")
+         logs.info(self.vocab_counts_df)
+         logs.info("filtered vocab")
+         logs.info(self.vocab_counts_filtered_df)
+
+     def load_vocab(self):
+         with open(self.vocab_counts_df_fid, "rb") as f:
+             self.vocab_counts_df = feather.read_feather(f)
+         # Handling for changes in how the index is saved.
+         self.vocab_counts_df = self._set_idx_col_names(self.vocab_counts_df)
+
+     def load_or_prepare_text_duplicates(self, save=True):
+         if self.use_cache and exists(self.dup_counts_df_fid):
+             with open(self.dup_counts_df_fid, "rb") as f:
+                 self.dup_counts_df = feather.read_feather(f)
+         elif self.dup_counts_df is None:
+             if not self.live:
+                 self.prepare_text_duplicates()
+                 if save:
+                     write_df(self.dup_counts_df, self.dup_counts_df_fid)
+         else:
+             if not self.live:
+                 # This happens when self.dup_counts_df is already defined;
+                 # This happens when general_statistics were calculated first,
+                 # since general statistics requires the number of duplicates
+                 if save:
+                     write_df(self.dup_counts_df, self.dup_counts_df_fid)
+
+     def load_general_stats(self):
+         self.general_stats_dict = json.load(
+             open(self.general_stats_json_fid, encoding="utf-8")
+         )
+         with open(self.sorted_top_vocab_df_fid, "rb") as f:
+             self.sorted_top_vocab_df = feather.read_feather(f)
+         self.text_nan_count = self.general_stats_dict[TEXT_NAN_CNT]
+         self.dedup_total = self.general_stats_dict[DEDUP_TOT]
+         self.total_words = self.general_stats_dict[TOT_WORDS]
+         self.total_open_words = self.general_stats_dict[TOT_OPEN_WORDS]
+
+     def prepare_general_stats(self):
+         if not self.live:
+             if self.tokenized_df is None:
+                 logs.warning("Tokenized dataset not yet loaded; doing so.")
+                 self.load_or_prepare_tokenized_df()
+             if self.vocab_counts_df is None:
+                 logs.warning("Vocab not yet loaded; doing so.")
+                 self.load_or_prepare_vocab()
+             self.sorted_top_vocab_df = self.vocab_counts_filtered_df.sort_values(
+                 "count", ascending=False
+             ).head(_TOP_N)
+             self.total_words = len(self.vocab_counts_df)
+             self.total_open_words = len(self.vocab_counts_filtered_df)
+             self.text_nan_count = int(self.tokenized_df.isnull().sum().sum())
+             self.prepare_text_duplicates()
+             self.dedup_total = sum(self.dup_counts_df[CNT])
+             self.general_stats_dict = {
+                 TOT_WORDS: self.total_words,
+                 TOT_OPEN_WORDS: self.total_open_words,
+                 TEXT_NAN_CNT: self.text_nan_count,
+                 DEDUP_TOT: self.dedup_total,
+             }
+
+     def prepare_text_duplicates(self):
+         if not self.live:
+             if self.tokenized_df is None:
+                 self.load_or_prepare_tokenized_df()
+             dup_df = self.tokenized_df[self.tokenized_df.duplicated([OUR_TEXT_FIELD])]
+             self.dup_counts_df = pd.DataFrame(
+                 dup_df.pivot_table(
+                     columns=[OUR_TEXT_FIELD], aggfunc="size"
+                 ).sort_values(ascending=False),
+                 columns=[CNT],
+             )
+             self.dup_counts_df[OUR_TEXT_FIELD] = self.dup_counts_df.index.copy()
+
+     def load_or_prepare_dataset(self, save=True):
+         """
+         Prepares the HF datasets and data frames containing the untokenized and
+         tokenized text as well as the label values.
+         self.tokenized_df is used further for calculating text lengths,
+         word counts, etc.
+         Args:
+             save: Store the calculated data to disk.
+
+         Returns:
+
+         """
+         logs.info("Doing text dset.")
+         self.load_or_prepare_text_dset(save)
+         #logs.info("Doing tokenized dataframe")
+         #self.load_or_prepare_tokenized_df(save)
+         logs.info("Doing dataset peek")
+         self.load_or_prepare_dset_peek(save)
+
+     def load_or_prepare_dset_peek(self, save=True):
+         if self.use_cache and exists(self.dset_peek_json_fid):
+             with open(self.dset_peek_json_fid, "r") as f:
+                 self.dset_peek = json.load(f)["dset peek"]
+         else:
+             if not self.live:
+                 if self.dset is None:
+                     self.get_base_dataset()
+                 self.dset_peek = self.dset[:100]
+                 if save:
+                     write_json({"dset peek": self.dset_peek}, self.dset_peek_json_fid)
+
+     def load_or_prepare_tokenized_df(self, save=True):
+         if self.use_cache and exists(self.tokenized_df_fid):
+             self.tokenized_df = feather.read_feather(self.tokenized_df_fid)
+         else:
+             if not self.live:
+                 # tokenize all text instances
+                 self.tokenized_df = self.do_tokenization()
+                 if save:
+                     logs.warning("Saving tokenized dataset to disk")
+                     # save tokenized text
+                     write_df(self.tokenized_df, self.tokenized_df_fid)
+
+     def load_or_prepare_text_dset(self, save=True):
+         if self.use_cache and exists(self.text_dset_fid):
+             # load extracted text
+             self.text_dset = load_from_disk(self.text_dset_fid)
+             logs.warning("Loaded dataset from disk")
+             logs.info(self.text_dset)
+         # ...Or load it from the server and store it anew
+         else:
+             if not self.live:
+                 self.prepare_text_dset()
598
+ if save:
599
+ # save extracted text instances
600
+ logs.warning("Saving dataset to disk")
601
+ self.text_dset.save_to_disk(self.text_dset_fid)
602
+
603
+ def prepare_text_dset(self):
604
+ if not self.live:
605
+ self.get_base_dataset()
606
+ # extract all text instances
607
+ self.text_dset = self.dset.map(
608
+ lambda examples: extract_field(
609
+ examples, self.text_field, OUR_TEXT_FIELD
610
+ ),
611
+ batched=True,
612
+ remove_columns=list(self.dset.features),
613
+ )
614
+ ## Addition: drop rows with missing text.
615
+ self.text_dset = self.text_dset.filter(lambda ex: ex["text"] is not None)
616
+
617
+ def do_tokenization(self):
618
+ """
619
+ Tokenizes the text dataset, lowercasing each instance.
+ :return: DataFrame with the tokenized text and per-instance token counts.
621
+ """
622
+ if self.text_dset is None:
623
+ self.load_or_prepare_text_dset()
624
+ sent_tokenizer = self.cvec.build_tokenizer()
625
+
626
+ def tokenize_batch(examples):
627
+ # TODO: lowercase should be an option
628
+ res = {
629
+ TOKENIZED_FIELD: [
630
+ tuple(sent_tokenizer(text.lower()))
631
+ for text in examples[OUR_TEXT_FIELD]
632
+ ]
633
+ }
634
+ res[LENGTH_FIELD] = [len(tok_text) for tok_text in res[TOKENIZED_FIELD]]
635
+ return res
636
+
637
+ tokenized_dset = self.text_dset.map(
638
+ tokenize_batch,
639
+ batched=True,
640
+ # remove_columns=[OUR_TEXT_FIELD], keep around to print
641
+ )
642
+ tokenized_df = pd.DataFrame(tokenized_dset)
643
+ return tokenized_df
644
+
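+ # Minimal sketch of the tokenizer behavior used in do_tokenization above
+ # (illustrative; self.cvec is a scikit-learn CountVectorizer built elsewhere):
+ #
+ #     from sklearn.feature_extraction.text import CountVectorizer
+ #     tokenize = CountVectorizer().build_tokenizer()
+ #     tokenize("Hello, world!".lower())  # -> ['hello', 'world']
+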
645
+ def set_label_field(self, label_field="label"):
646
+ """
647
+ Setter for label_field. Used in the CLI when a user asks for information
648
+ about labels, but does not specify the field;
649
+ 'label' is assumed as a default.
650
+ """
651
+ self.label_field = label_field
652
+
653
+ def load_or_prepare_labels(self, save=True):
654
+ # TODO: This is in a transitory state for creating fig cache.
655
+ # Clean up to be caching and reading everything correctly.
656
+ """
657
+ Extracts labels from the Dataset
658
+ :return:
659
+ """
660
+ # extracted labels
661
+ if len(self.label_field) > 0:
662
+ if self.use_cache and exists(self.fig_labels_json_fid):
663
+ self.fig_labels = read_plotly(self.fig_labels_json_fid)
664
+ elif self.use_cache and exists(self.label_dset_fid):
665
+ # load extracted labels
666
+ self.label_dset = load_from_disk(self.label_dset_fid)
667
+ self.label_df = self.label_dset.to_pandas()
668
+ self.fig_labels = make_fig_labels(
669
+ self.label_df, self.label_names, OUR_LABEL_FIELD
670
+ )
671
+ if save:
672
+ write_plotly(self.fig_labels, self.fig_labels_json_fid)
673
+ else:
674
+ if not self.live:
675
+ self.prepare_labels()
676
+ if save:
677
+ # save extracted label instances
678
+ self.label_dset.save_to_disk(self.label_dset_fid)
679
+ write_plotly(self.fig_labels, self.fig_labels_json_fid)
680
+
681
+ def prepare_labels(self):
682
+ if not self.live:
683
+ self.get_base_dataset()
684
+ self.label_dset = self.dset.map(
685
+ lambda examples: extract_field(
686
+ examples, self.label_field, OUR_LABEL_FIELD
687
+ ),
688
+ batched=True,
689
+ remove_columns=list(self.dset.features),
690
+ )
691
+ self.label_df = self.label_dset.to_pandas()
692
+ self.fig_labels = make_fig_labels(
693
+ self.label_df, self.label_names, OUR_LABEL_FIELD
694
+ )
695
+
696
+ def load_or_prepare_npmi(self):
697
+ self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=self.use_cache)
698
+ self.npmi_stats.load_or_prepare_npmi_terms()
699
+
700
+ def load_or_prepare_zipf(self, save=True):
701
+ # TODO: Current UI only uses the fig, meaning the self.z here is irrelevant
702
+ # when only reading from cache. Either the UI should use it, or it should
703
+ # be removed when reading in cache
704
+ if self.use_cache and exists(self.zipf_fig_fid) and exists(self.zipf_fid):
705
+ with open(self.zipf_fid, "r") as f:
706
+ zipf_dict = json.load(f)
707
+ self.z = Zipf()
708
+ self.z.load(zipf_dict)
709
+ self.zipf_fig = read_plotly(self.zipf_fig_fid)
710
+ elif self.use_cache and exists(self.zipf_fid):
711
+ # TODO: Read zipf data so that the vocab is there.
712
+ with open(self.zipf_fid, "r") as f:
713
+ zipf_dict = json.load(f)
714
+ self.z = Zipf()
715
+ self.z.load(zipf_dict)
716
+ self.zipf_fig = make_zipf_fig(self.vocab_counts_df, self.z)
717
+ if save:
718
+ write_plotly(self.zipf_fig, self.zipf_fig_fid)
719
+ else:
720
+ self.z = Zipf(self.vocab_counts_df)
721
+ self.zipf_fig = make_zipf_fig(self.vocab_counts_df, self.z)
722
+ if save:
723
+ write_zipf_data(self.z, self.zipf_fid)
724
+ write_plotly(self.zipf_fig, self.zipf_fig_fid)
725
+
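+ # Cache round-trip sketch for the figures handled above (illustrative; uses
+ # the module's own write_plotly/read_plotly helpers defined further below):
+ #
+ #     write_plotly(self.zipf_fig, self.zipf_fig_fid)
+ #     fig_again = read_plotly(self.zipf_fig_fid)  # same figure, restored from JSON
+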
726
+ def _set_idx_col_names(self, input_vocab_df):
727
+ if input_vocab_df.index.name != VOCAB and VOCAB in input_vocab_df.columns:
728
+ input_vocab_df = input_vocab_df.set_index([VOCAB])
729
+ input_vocab_df[VOCAB] = input_vocab_df.index
730
+ return input_vocab_df
731
+
732
+
733
+ class nPMIStatisticsCacheClass:
734
+ """ "Class to interface between the app and the nPMI class
735
+ by calling the nPMI class with the user's selections."""
736
+
737
+ def __init__(self, dataset_stats, use_cache=False):
738
+ self.live = dataset_stats.live
739
+ self.dstats = dataset_stats
740
+ self.pmi_cache_path = pjoin(self.dstats.cache_path, "pmi_files")
741
+ if not isdir(self.pmi_cache_path):
742
+ logs.warning("Creating pmi cache directory %s." % self.pmi_cache_path)
743
+ # We need to preprocess everything.
744
+ mkdir(self.pmi_cache_path)
745
+ self.joint_npmi_df_dict = {}
746
+ # TODO: Users ideally can type in whatever words they want.
747
+ self.termlist = _IDENTITY_TERMS
748
+ # termlist terms that are available more than _MIN_VOCAB_COUNT times
749
+ self.available_terms = _IDENTITY_TERMS
750
+ logs.info(self.termlist)
751
+ self.use_cache = use_cache
752
+ # TODO: Let users specify
753
+ self.open_class_only = True
754
+ self.min_vocab_count = self.dstats.min_vocab_count
755
+ self.subgroup_files = {}
756
+ self.npmi_terms_fid = pjoin(self.dstats.cache_path, "npmi_terms.json")
757
+
758
+ def load_or_prepare_npmi_terms(self):
759
+ """
760
+ Figures out what identity terms the user can select, based on whether
761
+ they occur more than self.min_vocab_count times
762
+ :return: Identity terms occurring at least self.min_vocab_count times.
763
+ """
764
+ # TODO: Add the user's ability to select subgroups.
765
+ # TODO: Make the min_vocab_count value here selectable by the user.
766
+ if (
767
+ self.use_cache
768
+ and exists(self.npmi_terms_fid)
769
+ and json.load(open(self.npmi_terms_fid))["available terms"] != []
770
+ ):
771
+ available_terms = json.load(open(self.npmi_terms_fid))["available terms"]
772
+ else:
773
+ true_false = [
774
+ term in self.dstats.vocab_counts_df.index for term in self.termlist
775
+ ]
776
+ word_list_tmp = [x for x, y in zip(self.termlist, true_false) if y]
777
+ true_false_counts = [
778
+ self.dstats.vocab_counts_df.loc[word, CNT] >= self.min_vocab_count
779
+ for word in word_list_tmp
780
+ ]
781
+ available_terms = [
782
+ word for word, y in zip(word_list_tmp, true_false_counts) if y
783
+ ]
784
+ logs.info(available_terms)
785
+ with open(self.npmi_terms_fid, "w+") as f:
786
+ json.dump({"available terms": available_terms}, f)
787
+ self.available_terms = available_terms
788
+ return available_terms
789
+
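+ # Equivalent pandas one-liner for the term filtering above (illustrative
+ # sketch; `vocab_counts_df` is indexed by word with a "count" column):
+ #
+ #     available = [t for t in self.termlist
+ #                  if t in self.dstats.vocab_counts_df.index
+ #                  and self.dstats.vocab_counts_df.loc[t, CNT] >= self.min_vocab_count]
+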
790
+ def load_or_prepare_joint_npmi(self, subgroup_pair):
791
+ """
792
+ Run on the fly, while the app is already open,
793
+ as it depends on the subgroup terms that the user chooses
794
+ :param subgroup_pair:
795
+ :return:
796
+ """
797
+ # Canonical ordering for subgroup_list
798
+ subgroup_pair = sorted(subgroup_pair)
799
+ subgroup1 = subgroup_pair[0]
800
+ subgroup2 = subgroup_pair[1]
801
+ subgroups_str = "-".join(subgroup_pair)
802
+ if not isdir(self.pmi_cache_path):
803
+ logs.warning("Creating cache")
804
+ # We need to preprocess everything.
805
+ # This should eventually all go into a prepare_dataset CLI
806
+ mkdir(self.pmi_cache_path)
807
+ joint_npmi_fid = pjoin(self.pmi_cache_path, subgroups_str + "_npmi.csv")
808
+ subgroup_files = define_subgroup_files(subgroup_pair, self.pmi_cache_path)
809
+ # Defines the filenames for the cache files from the selected subgroups.
810
+ # Get as much precomputed data as we can.
811
+ if self.use_cache and exists(joint_npmi_fid):
812
+ # When everything is already computed for the selected subgroups.
813
+ logs.info("Loading cached joint npmi")
814
+ joint_npmi_df = self.load_joint_npmi_df(joint_npmi_fid)
815
+ npmi_display_cols = [
816
+ "npmi-bias",
817
+ subgroup1 + "-npmi",
818
+ subgroup2 + "-npmi",
819
+ subgroup1 + "-count",
820
+ subgroup2 + "-count",
821
+ ]
822
+ joint_npmi_df = joint_npmi_df[npmi_display_cols]
823
+ # When maybe some things have been computed for the selected subgroups.
824
+ else:
825
+ if not self.live:
826
+ logs.info("Preparing new joint npmi")
827
+ joint_npmi_df, subgroup_dict = self.prepare_joint_npmi_df(
828
+ subgroup_pair, subgroup_files
829
+ )
830
+ # Cache new results
831
+ logs.info("Writing out.")
832
+ for subgroup in subgroup_pair:
833
+ write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files)
834
+ with open(joint_npmi_fid, "w+") as f:
835
+ joint_npmi_df.to_csv(f)
836
+ else:
837
+ joint_npmi_df = pd.DataFrame()
838
+ logs.info("The joint npmi df is")
839
+ logs.info(joint_npmi_df)
840
+ return joint_npmi_df
841
+
842
+ def load_joint_npmi_df(self, joint_npmi_fid):
843
+ """
844
+ Reads in a saved dataframe with all of the paired results.
845
+ :param joint_npmi_fid:
846
+ :return: paired results
847
+ """
848
+ with open(joint_npmi_fid, "rb") as f:
849
+ joint_npmi_df = pd.read_csv(f)
850
+ joint_npmi_df = self._set_idx_cols_from_cache(joint_npmi_df)
851
+ return joint_npmi_df.dropna()
852
+
853
+ def prepare_joint_npmi_df(self, subgroup_pair, subgroup_files):
854
+ """
855
+ Computes the nPMI bias based on the given subgroups.
+ Handles cases where some of the selected subgroups have cached nPMI
+ computations but others don't, computing everything afresh if there
+ are no cached files.
859
+ :param subgroup_pair:
860
+ :return: Dataframe with nPMI for the words, nPMI bias between the words.
861
+ """
862
+ subgroup_dict = {}
863
+ # When npmi is computed for some (but not all) of subgroup_list
864
+ for subgroup in subgroup_pair:
865
+ logs.info("Load or failing...")
866
+ # When subgroup npmi has been computed in a prior session.
867
+ cached_results = self.load_or_fail_cached_npmi_scores(
868
+ subgroup, subgroup_files[subgroup]
869
+ )
870
+ # If cached results were found (i.e., the function did not return False), use them.
871
+ if cached_results:
872
+ # FYI: subgroup_cooc_df, subgroup_pmi_df, subgroup_npmi_df = cached_results
873
+ # Holds the previous sessions' data for use in this session.
874
+ subgroup_dict[subgroup] = cached_results
875
+ logs.info("Calculating for subgroup list")
876
+ joint_npmi_df, subgroup_dict = self.do_npmi(subgroup_pair, subgroup_dict)
877
+ return joint_npmi_df.dropna(), subgroup_dict
878
+
879
+ # TODO: Update pairwise assumption
880
+ def do_npmi(self, subgroup_pair, subgroup_dict):
881
+ """
882
+ Calculates nPMI for given identity terms and the nPMI bias between.
883
+ :param subgroup_pair: List of identity terms to calculate the bias for
884
+ :return: Subset of data for the UI
885
+ :return: Selected identity term's co-occurrence counts with
886
+ other words, pmi per word, and nPMI per word.
887
+ """
888
+ logs.info("Initializing npmi class")
889
+ npmi_obj = self.set_npmi_obj()
890
+ # Canonical ordering used
891
+ subgroup_pair = tuple(sorted(subgroup_pair))
892
+ # Calculating nPMI statistics
893
+ for subgroup in subgroup_pair:
894
+ # If the subgroup data is already computed, grab it.
895
+ # TODO: Should we set idx and column names similarly to how we set them for cached files?
896
+ if subgroup not in subgroup_dict:
897
+ logs.info("Calculating statistics for %s" % subgroup)
898
+ vocab_cooc_df, pmi_df, npmi_df = npmi_obj.calc_metrics(subgroup)
899
+ # Store the nPMI information for the current subgroups
900
+ subgroup_dict[subgroup] = (vocab_cooc_df, pmi_df, npmi_df)
901
+ # Pair the subgroups together, indexed by all words that
902
+ # co-occur between them.
903
+ logs.info("Computing pairwise npmi bias")
904
+ paired_results = npmi_obj.calc_paired_metrics(subgroup_pair, subgroup_dict)
905
+ UI_results = make_npmi_fig(paired_results, subgroup_pair)
906
+ return UI_results, subgroup_dict
907
+
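+ # For reference, the quantities computed by do_npmi above (sketch of the
+ # standard definitions; s = identity term, w = co-occurring word):
+ #
+ #     pmi(w; s)  = log( p(w|s) / p(w) )
+ #     npmi(w; s) = pmi(w; s) / -log( p(w, s) )
+ #     npmi-bias  = npmi(w; s1) - npmi(w; s2)
+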
908
+ def set_npmi_obj(self):
909
+ """
910
+ Initializes the nPMI class with the given words and tokenized sentences.
911
+ :return:
912
+ """
913
+ npmi_obj = nPMI(self.dstats.vocab_counts_df, self.dstats.tokenized_df)
914
+ return npmi_obj
915
+
916
+ def load_or_fail_cached_npmi_scores(self, subgroup, subgroup_fids):
917
+ """
918
+ Reads cached scores from the specified subgroup files
919
+ :param subgroup: string of the selected identity term
920
+ :return:
921
+ """
922
+ # TODO: Ordering of npmi, pmi, vocab triple should be consistent
923
+ subgroup_npmi_fid, subgroup_pmi_fid, subgroup_cooc_fid = subgroup_fids
924
+ if (
925
+ exists(subgroup_npmi_fid)
926
+ and exists(subgroup_pmi_fid)
927
+ and exists(subgroup_cooc_fid)
928
+ ):
929
+ logs.info("Reading in pmi data....")
930
+ with open(subgroup_cooc_fid, "rb") as f:
931
+ subgroup_cooc_df = pd.read_csv(f)
932
+ logs.info("pmi")
933
+ with open(subgroup_pmi_fid, "rb") as f:
934
+ subgroup_pmi_df = pd.read_csv(f)
935
+ logs.info("npmi")
936
+ with open(subgroup_npmi_fid, "rb") as f:
937
+ subgroup_npmi_df = pd.read_csv(f)
938
+ subgroup_cooc_df = self._set_idx_cols_from_cache(
939
+ subgroup_cooc_df, subgroup, "count"
940
+ )
941
+ subgroup_pmi_df = self._set_idx_cols_from_cache(
942
+ subgroup_pmi_df, subgroup, "pmi"
943
+ )
944
+ subgroup_npmi_df = self._set_idx_cols_from_cache(
945
+ subgroup_npmi_df, subgroup, "npmi"
946
+ )
947
+ return subgroup_cooc_df, subgroup_pmi_df, subgroup_npmi_df
948
+ return False
949
+
950
+ def _set_idx_cols_from_cache(self, csv_df, subgroup=None, calc_str=None):
951
+ """
952
+ Helps make sure all of the read-in files can be accessed within code
953
+ via standardized indices and column names.
954
+ :param csv_df:
955
+ :param subgroup:
956
+ :param calc_str:
957
+ :return:
958
+ """
959
+ # Pandas writes the index out as an unnamed column, so restore it as the index.
960
+ if "Unnamed: 0" in csv_df.columns:
961
+ csv_df = csv_df.set_index("Unnamed: 0")
962
+ csv_df.index.name = WORD
963
+ elif WORD in csv_df.columns:
964
+ csv_df = csv_df.set_index(WORD)
965
+ csv_df.index.name = WORD
966
+ elif VOCAB in csv_df.columns:
967
+ csv_df = csv_df.set_index(VOCAB)
968
+ csv_df.index.name = WORD
969
+ if subgroup and calc_str:
970
+ csv_df.columns = [subgroup + "-" + calc_str]
971
+ elif subgroup:
972
+ csv_df.columns = [subgroup]
973
+ elif calc_str:
974
+ csv_df.columns = [calc_str]
975
+ return csv_df
976
+
977
+ def get_available_terms(self):
978
+ return self.load_or_prepare_npmi_terms()
979
+
980
+
981
+ def dummy(doc):
982
+ return doc
983
+
984
+
985
+ def count_vocab_frequencies(tokenized_df):
986
+ """
987
+ Based on an input pandas DataFrame with a tokenized-text column,
+ this function counts the occurrences of all words.
+ :return: [num_words x 1] DataFrame with the rows corresponding to the
+ different vocabulary words and the column holding each word's total count.
991
+ """
992
+
993
+ cvec = CountVectorizer(
994
+ tokenizer=dummy,
995
+ preprocessor=dummy,
996
+ )
997
+ # We do this to calculate per-word statistics
998
+ # Fast calculation of single word counts
999
+ logs.info(
1000
+ "Fitting dummy tokenization to make matrix using the previous tokenization"
1001
+ )
1002
+ cvec.fit(tokenized_df[TOKENIZED_FIELD])
1003
+ document_matrix = cvec.transform(tokenized_df[TOKENIZED_FIELD])
1004
+ batches = np.linspace(0, tokenized_df.shape[0], _NUM_VOCAB_BATCHES).astype(int)
1005
+ i = 0
1006
+ tf = []
1007
+ while i < len(batches) - 1:
1008
+ logs.info("%s of %s vocab batches" % (str(i), str(len(batches))))
1009
+ batch_result = np.sum(
1010
+ document_matrix[batches[i] : batches[i + 1]].toarray(), axis=0
1011
+ )
1012
+ tf.append(batch_result)
1013
+ i += 1
1014
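+ # NOTE: get_feature_names() was removed in scikit-learn 1.2;
+ # on newer versions, use cvec.get_feature_names_out() below.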
+ word_count_df = pd.DataFrame(
1015
+ [np.sum(tf, axis=0)], columns=cvec.get_feature_names()
1016
+ ).transpose()
1017
+ # Now organize everything into the dataframes
1018
+ word_count_df.columns = [CNT]
1019
+ word_count_df.index.name = WORD
1020
+ return word_count_df
1021
+
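+ # Why the dummy tokenizer above: the text is already tokenized, so identity
+ # functions make CountVectorizer build the document-term matrix directly.
+ # Standalone sketch (illustrative, not part of the original code):
+ #
+ #     cv = CountVectorizer(tokenizer=dummy, preprocessor=dummy)
+ #     X = cv.fit_transform([["a", "b", "a"], ["b", "c"]])
+ #     X.sum(axis=0)  # per-word totals: a=2, b=2, c=1
+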
1022
+
1023
+ def calc_p_word(word_count_df):
1024
+ # p(word)
1025
+ word_count_df[PROP] = word_count_df[CNT] / float(sum(word_count_df[CNT]))
1026
+ vocab_counts_df = pd.DataFrame(word_count_df.sort_values(by=CNT, ascending=False))
1027
+ vocab_counts_df[VOCAB] = vocab_counts_df.index
1028
+ return vocab_counts_df
1029
+
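+ # Worked example for calc_p_word above (illustrative): with counts
+ # {"the": 6, "cat": 3, "sat": 1}, the total is 10, so the proportions are
+ # 0.6, 0.3 and 0.1, and the returned frame is sorted by count (descending)
+ # with a "vocab" column mirroring the index.
+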
1030
+
1031
+ def filter_vocab(vocab_counts_df):
1032
+ # TODO: Add warnings (which words are missing) to log file?
1033
+ filtered_vocab_counts_df = vocab_counts_df.drop(_CLOSED_CLASS, errors="ignore")
1034
+ filtered_count = filtered_vocab_counts_df[CNT]
1035
+ filtered_count_denom = float(sum(filtered_vocab_counts_df[CNT]))
1036
+ filtered_vocab_counts_df[PROP] = filtered_count / filtered_count_denom
1037
+ return filtered_vocab_counts_df
1038
+
1039
+
1040
+ ## Figures ##
1041
+
1042
+
1043
+ def write_plotly(fig, fid):
1044
+ write_json(plotly.io.to_json(fig), fid)
1045
+
1046
+
1047
+ def read_plotly(fid):
1048
+ fig = plotly.io.from_json(json.load(open(fid, encoding="utf-8")))
1049
+ return fig
1050
+
1051
+
1052
+ def make_fig_lengths(tokenized_df, length_field):
1053
+ fig_tok_length, axs = plt.subplots(figsize=(15, 6), dpi=150)
1054
+ sns.histplot(data=tokenized_df[length_field], kde=True, bins=100, ax=axs)
1055
+ sns.rugplot(data=tokenized_df[length_field], ax=axs)
1056
+ return fig_tok_length
1057
+
1058
+
1059
+ def make_fig_labels(label_df, label_names, label_field):
1060
+ labels = label_df[label_field].unique()
1061
+ label_sums = [len(label_df[label_df[label_field] == label]) for label in labels]
1062
+ fig_labels = px.pie(label_df, values=label_sums, names=label_names)
1063
+ return fig_labels
1064
+
1065
+
1066
+ def make_zipf_fig_ranked_word_list(vocab_df, unique_counts, unique_ranks):
1067
+ ranked_words = {}
1068
+ for count, rank in zip(unique_counts, unique_ranks):
1069
+ vocab_df.loc[vocab_df[CNT] == count, "rank"] = rank  # .loc avoids a chained-assignment no-op
1070
+ ranked_words[rank] = ",".join(
1071
+ vocab_df[vocab_df[CNT] == count].index.astype(str)
1072
+ ) # Use the hovertext kw argument for hover text
1073
+ ranked_words_list = [wrds for rank, wrds in sorted(ranked_words.items())]
1074
+ return ranked_words_list
1075
+
1076
+
1077
+ def make_npmi_fig(paired_results, subgroup_pair):
1078
+ subgroup1, subgroup2 = subgroup_pair
1079
+ UI_results = pd.DataFrame()
1080
+ if "npmi-bias" in paired_results:
1081
+ UI_results["npmi-bias"] = paired_results["npmi-bias"].astype(float)
1082
+ UI_results[subgroup1 + "-npmi"] = paired_results["npmi"][
1083
+ subgroup1 + "-npmi"
1084
+ ].astype(float)
1085
+ UI_results[subgroup1 + "-count"] = paired_results["count"][
1086
+ subgroup1 + "-count"
1087
+ ].astype(int)
1088
+ if subgroup1 != subgroup2:
1089
+ UI_results[subgroup2 + "-npmi"] = paired_results["npmi"][
1090
+ subgroup2 + "-npmi"
1091
+ ].astype(float)
1092
+ UI_results[subgroup2 + "-count"] = paired_results["count"][
1093
+ subgroup2 + "-count"
1094
+ ].astype(int)
1095
+ return UI_results.sort_values(by="npmi-bias", ascending=True)
1096
+
1097
+
1098
+ def make_zipf_fig(vocab_counts_df, z):
1099
+ zipf_counts = z.calc_zipf_counts(vocab_counts_df)
1100
+ unique_counts = z.uniq_counts
1101
+ unique_ranks = z.uniq_ranks
1102
+ ranked_words_list = make_zipf_fig_ranked_word_list(
1103
+ vocab_counts_df, unique_counts, unique_ranks
1104
+ )
1105
+ zmin = z.get_xmin()
1106
+ logs.info("zipf counts is")
1107
+ logs.info(zipf_counts)
1108
+ layout = go.Layout(xaxis=dict(range=[0, 100]))
1109
+ fig = go.Figure(
1110
+ data=[
1111
+ go.Bar(
1112
+ x=z.uniq_ranks,
1113
+ y=z.uniq_counts,
1114
+ hovertext=ranked_words_list,
1115
+ name="Word Rank Frequency",
1116
+ )
1117
+ ],
1118
+ layout=layout,
1119
+ )
1120
+ fig.add_trace(
1121
+ go.Scatter(
1122
+ x=z.uniq_ranks[zmin : len(z.uniq_ranks)],
1123
+ y=zipf_counts[zmin : len(z.uniq_ranks)],
1124
+ hovertext=ranked_words_list[zmin : len(z.uniq_ranks)],
1125
+ line=go.scatter.Line(color="crimson", width=3),
1126
+ name="Zipf Predicted Frequency",
1127
+ )
1128
+ )
1129
+ # Customize aspect
1130
+ # fig.update_traces(marker_color='limegreen',
1131
+ # marker_line_width=1.5, opacity=0.6)
1132
+ fig.update_layout(title_text="Word Counts, Observed and Predicted by Zipf")
1133
+ fig.update_layout(xaxis_title="Word Rank")
1134
+ fig.update_layout(yaxis_title="Frequency")
1135
+ fig.update_layout(legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.10))
1136
+ return fig
1137
+
1138
+
1139
+ ## Input/Output ##
1140
+
1141
+
1142
+ def define_subgroup_files(subgroup_list, pmi_cache_path):
1143
+ """
1144
+ Sets the file ids for the input identity terms
1145
+ :param subgroup_list: List of identity terms
1146
+ :return:
1147
+ """
1148
+ subgroup_files = {}
1149
+ for subgroup in subgroup_list:
1150
+ # TODO: Should the pmi, npmi, and count just be one file?
1151
+ subgroup_npmi_fid = pjoin(pmi_cache_path, subgroup + "_npmi.csv")
1152
+ subgroup_pmi_fid = pjoin(pmi_cache_path, subgroup + "_pmi.csv")
1153
+ subgroup_cooc_fid = pjoin(pmi_cache_path, subgroup + "_vocab_cooc.csv")
1154
+ subgroup_files[subgroup] = (
1155
+ subgroup_npmi_fid,
1156
+ subgroup_pmi_fid,
1157
+ subgroup_cooc_fid,
1158
+ )
1159
+ return subgroup_files
1160
+
1161
+
1162
+ ## Input/Output ##
1163
+
1164
+
1165
+ def intersect_dfs(df_dict):
1166
+ started = 0
1167
+ new_df = None
1168
+ for key, df in df_dict.items():
1169
+ if df is None:
1170
+ continue
1171
+ for key2, df2 in df_dict.items():
1172
+ if df2 is None:
1173
+ continue
1174
+ if key == key2:
1175
+ continue
1176
+ if started:
1177
+ new_df = new_df.join(df2, how="inner", lsuffix="1", rsuffix="2")
1178
+ else:
1179
+ new_df = df.join(df2, how="inner", lsuffix="1", rsuffix="2")
1180
+ started = 1
1181
+ return new_df.copy()
1182
+
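+ # Behavior sketch (illustrative): given a dict of DataFrames sharing an index
+ # (e.g. vocab frames indexed by word), intersect_dfs inner-joins them so only
+ # index entries present across the frames survive; the "1"/"2" suffixes
+ # disambiguate overlapping column names.
+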
1183
+
1184
+ def write_df(df, df_fid):
1185
+ feather.write_feather(df, df_fid)
1186
+
1187
+
1188
+ def write_json(json_dict, json_fid):
1189
+ with open(json_fid, "w", encoding="utf-8") as f:
1190
+ json.dump(json_dict, f)
1191
+
1192
+
1193
+ def write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files):
1194
+ """
1195
+ Saves the calculated nPMI statistics to their output files.
1196
+ Includes the npmi scores for each identity term, the pmi scores, and the
1197
+ co-occurrence counts of the identity term with all the other words
1198
+ :param subgroup: Identity term
1199
+ :return:
1200
+ """
1201
+ subgroup_fids = subgroup_files[subgroup]
1202
+ subgroup_npmi_fid, subgroup_pmi_fid, subgroup_cooc_fid = subgroup_fids
1203
+ subgroup_dfs = subgroup_dict[subgroup]
1204
+ subgroup_cooc_df, subgroup_pmi_df, subgroup_npmi_df = subgroup_dfs
1205
+ with open(subgroup_npmi_fid, "w+") as f:
1206
+ subgroup_npmi_df.to_csv(f)
1207
+ with open(subgroup_pmi_fid, "w+") as f:
1208
+ subgroup_pmi_df.to_csv(f)
1209
+ with open(subgroup_cooc_fid, "w+") as f:
1210
+ subgroup_cooc_df.to_csv(f)
1211
+
1212
+
1213
+ def write_zipf_data(z, zipf_fid):
1214
+ zipf_dict = {}
1215
+ zipf_dict["xmin"] = int(z.xmin)
1216
+ zipf_dict["xmax"] = int(z.xmax)
1217
+ zipf_dict["alpha"] = float(z.alpha)
1218
+ zipf_dict["ks_distance"] = float(z.distance)
1219
+ zipf_dict["p-value"] = float(z.ks_test.pvalue)
1220
+ zipf_dict["uniq_counts"] = [int(count) for count in z.uniq_counts]
1221
+ zipf_dict["uniq_ranks"] = [int(rank) for rank in z.uniq_ranks]
1222
+ with open(zipf_fid, "w+", encoding="utf-8") as f:
1223
+ json.dump(zipf_dict, f)
data_measurements/dataset_utils.py ADDED
@@ -0,0 +1,296 @@
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ from dataclasses import asdict
17
+ from os.path import exists
18
+
19
+ import pandas as pd
20
+ from datasets import Dataset, get_dataset_infos, load_dataset, load_from_disk
21
+
22
+ # treating inf values as NaN as well
23
+ pd.set_option("use_inf_as_na", True)
24
+
25
+ ## String names used in Hugging Face dataset configs.
26
+ HF_FEATURE_FIELD = "features"
27
+ HF_LABEL_FIELD = "label"
28
+ HF_DESC_FIELD = "description"
29
+
30
+ CACHE_DIR = "cache_dir"
31
+ ## String names we are using within this code.
32
+ # These are not coming from the stored dataset nor HF config,
33
+ # but rather used as identifiers in our dicts and dataframes.
34
+ OUR_TEXT_FIELD = "text"
35
+ OUR_LABEL_FIELD = "label"
36
+ TOKENIZED_FIELD = "tokenized_text"
37
+ EMBEDDING_FIELD = "embedding"
38
+ LENGTH_FIELD = "length"
39
+ VOCAB = "vocab"
40
+ WORD = "word"
41
+ CNT = "count"
42
+ PROP = "proportion"
43
+ TEXT_NAN_CNT = "text_nan_count"
44
+ TXT_LEN = "text lengths"
45
+ DEDUP_TOT = "dedup_total"
46
+ TOT_WORDS = "total words"
47
+ TOT_OPEN_WORDS = "total open words"
48
+
49
+ _DATASET_LIST = [
50
+ "c4",
51
+ "squad",
52
+ "squad_v2",
53
+ "hate_speech18",
54
+ "hate_speech_offensive",
55
+ "glue",
56
+ "super_glue",
57
+ "wikitext",
58
+ "imdb",
59
+ "HuggingFaceM4/OBELICS",
60
+ ]
61
+
62
+ _STREAMABLE_DATASET_LIST = [
63
+ "c4",
64
+ "wikitext",
65
+ "HuggingFaceM4/OBELICS",
66
+ ]
67
+
68
+ _MAX_ROWS = 100
69
+
70
+
71
+ def load_truncated_dataset(
72
+ dataset_name,
73
+ config_name,
74
+ split_name,
75
+ num_rows=_MAX_ROWS,
76
+ cache_name=None,
77
+ use_cache=True,
78
+ use_streaming=True,
79
+ ):
80
+ """
81
+ This function loads the first `num_rows` items of a dataset for a
82
+ given `config_name` and `split_name`.
83
+ If `cache_name` exists, the truncated dataset is loaded from `cache_name`.
84
+ Otherwise, a new truncated dataset is created and immediately saved
85
+ to `cache_name`.
86
+ When the dataset is streamable, we iterate through the first
87
+ `num_rows` examples in streaming mode, write them to a jsonl file,
88
+ then create a new dataset from the json.
89
+ This is the most direct way to make a Dataset from an IterableDataset
90
+ as of datasets version 1.6.1.
91
+ Otherwise, we download the full dataset and select the first
92
+ `num_rows` items
93
+ Args:
94
+ dataset_name (string):
95
+ dataset id in the dataset library
96
+ config_name (string):
97
+ dataset configuration
98
+ split_name (string):
99
+ split name
100
+ num_rows (int):
101
+ number of rows to truncate the dataset to
102
+ cache_name (string):
103
+ name of the cache directory
104
+ use_cache (bool):
105
+ whether to load form the cache if it exists
106
+ use_streaming (bool):
107
+ whether to use streaming when the dataset supports it
108
+ Returns:
109
+ Dataset: the truncated dataset as a Dataset object
110
+ """
111
+ if cache_name is None:
112
+ cache_name = f"{dataset_name}_{config_name}_{split_name}_{num_rows}"
113
+ if exists(cache_name):
114
+ dataset = load_from_disk(cache_name)
115
+ else:
116
+ if use_streaming and dataset_name in _STREAMABLE_DATASET_LIST:
117
+ iterable_dataset = load_dataset(
118
+ dataset_name,
119
+ name=config_name,
120
+ split=split_name,
121
+ streaming=True,
122
+ ).take(num_rows)
123
+ rows = list(iterable_dataset)
124
+ f = open("temp.jsonl", "w", encoding="utf-8")
125
+ for row in rows:
126
+ _ = f.write(json.dumps(row) + "\n")
127
+ f.close()
128
+ dataset = Dataset.from_json(
129
+ "temp.jsonl", features=iterable_dataset.features, split=split_name
130
+ )
131
+ else:
132
+ full_dataset = load_dataset(
133
+ dataset_name,
134
+ name=config_name,
135
+ split=split_name,
136
+ )
137
+ dataset = full_dataset.select(range(num_rows))
138
+ dataset.save_to_disk(cache_name)
139
+ return dataset
140
+
141
+
142
+ def intersect_dfs(df_dict):
143
+ started = 0
144
+ new_df = None
145
+ for key, df in df_dict.items():
146
+ if df is None:
147
+ continue
148
+ for key2, df2 in df_dict.items():
149
+ if df2 is None:
150
+ continue
151
+ if key == key2:
152
+ continue
153
+ if started:
154
+ new_df = new_df.join(df2, how="inner", lsuffix="1", rsuffix="2")
155
+ else:
156
+ new_df = df.join(df2, how="inner", lsuffix="1", rsuffix="2")
157
+ started = 1
158
+ return new_df.copy()
159
+
160
+
161
+ def get_typed_features(features, ftype="string", parents=None):
162
+ """
163
+ Recursively get a list of all features of a certain dtype
164
+ :param features:
165
+ :param ftype:
166
+ :param parents:
167
+ :return: a list of tuples > e.g. ('A', 'B', 'C') for feature example['A']['B']['C']
168
+ """
169
+ if parents is None:
170
+ parents = []
171
+ typed_features = []
172
+ for name, feat in features.items():
173
+ if isinstance(feat, dict):
174
+ if feat.get("dtype", None) == ftype or feat.get("feature", {}).get(
175
+ ("dtype", None) == ftype
176
+ ):
177
+ typed_features += [tuple(parents + [name])]
178
+ elif "feature" in feat:
179
+ if feat["feature"].get("dtype", None) == ftype:
180
+ typed_features += [tuple(parents + [name])]
181
+ elif isinstance(feat["feature"], dict):
182
+ typed_features += get_typed_features(
183
+ feat["feature"], ftype, parents + [name]
184
+ )
185
+ else:
186
+ for k, v in feat.items():
187
+ if isinstance(v, dict):
188
+ typed_features += get_typed_features(
189
+ v, ftype, parents + [name, k]
190
+ )
191
+ elif name == "dtype" and feat == ftype:
192
+ typed_features += [tuple(parents)]
193
+ return typed_features
194
+
195
+
196
+ def get_label_features(features, parents=None):
197
+ """
198
+ Recursively get a list of all features that are ClassLabels
199
+ :param features:
200
+ :param parents:
201
+ :return: pairs of tuples as above and the list of class names
202
+ """
203
+ if parents is None:
204
+ parents = []
205
+ label_features = []
206
+ for name, feat in features.items():
207
+ if isinstance(feat, dict):
208
+ if "names" in feat:
209
+ label_features += [(tuple(parents + [name]), feat["names"])]
210
+ elif "feature" in feat:
211
+ if "names" in feat:
212
+ label_features += [
213
+ (tuple(parents + [name]), feat["feature"]["names"])
214
+ ]
215
+ elif isinstance(feat["feature"], dict):
216
+ label_features += get_label_features(
217
+ feat["feature"], parents + [name]
218
+ )
219
+ else:
220
+ for k, v in feat.items():
221
+ if isinstance(v, dict):
222
+ label_features += get_label_features(v, parents + [name, k])
223
+ elif name == "names":
224
+ label_features += [(tuple(parents), feat)]
225
+ return label_features
226
+
227
+
228
+ # get the info we need for the app sidebar in dict format
229
+ def dictionarize_info(dset_info):
230
+ info_dict = asdict(dset_info)
231
+ res = {
232
+ "config_name": info_dict["config_name"],
233
+ "splits": {
234
+ spl: 100 #spl_info["num_examples"]
235
+ for spl, spl_info in info_dict["splits"].items()
236
+ },
237
+ "features": {
238
+ "string": get_typed_features(info_dict["features"], "string"),
239
+ "int32": get_typed_features(info_dict["features"], "int32"),
240
+ "float32": get_typed_features(info_dict["features"], "float32"),
241
+ "label": get_label_features(info_dict["features"]),
242
+ },
243
+ "description": dset_info.description,
244
+ }
245
+ return res
246
+
247
+
248
+ def get_dataset_info_dicts(dataset_id=None):
249
+ """
250
+ Creates a dict from dataset configs.
251
+ Uses the datasets lib's get_dataset_infos
252
+ :return: Dictionary mapping dataset names to their configurations
253
+ """
254
+ if dataset_id is not None:
255
+ ds_name_to_conf_dict = {
256
+ dataset_id: {
257
+ config_name: dictionarize_info(config_info)
258
+ for config_name, config_info in get_dataset_infos(dataset_id).items()
259
+ }
260
+ }
261
+ else:
262
+ ds_name_to_conf_dict = {
263
+ ds_id: {
264
+ config_name: dictionarize_info(config_info)
265
+ for config_name, config_info in get_dataset_infos(ds_id).items()
266
+ }
267
+ for ds_id in _DATASET_LIST
268
+ }
269
+ return ds_name_to_conf_dict
270
+
271
+
272
+ # get all instances of a specific field in a dataset
273
+ def extract_field(examples, field_path, new_field_name=None):
274
+ if new_field_name is None:
275
+ new_field_name = "_".join(field_path)
276
+ field_list = []
277
+ # TODO: Breaks the CLI if this isn't checked.
278
+ if isinstance(field_path, str):
279
+ field_path = [field_path]
280
+ item_list = examples[field_path[0]]
281
+ for field_name in field_path[1:]:
282
+ item_list = [
283
+ next_item
284
+ for item in item_list
285
+ for next_item in (
286
+ item[field_name]
287
+ if isinstance(item[field_name], list)
288
+ else [item[field_name]]
289
+ )
290
+ ]
291
+ field_list += [
292
+ field
293
+ for item in item_list
294
+ for field in (item if isinstance(item, list) else [item])
295
+ ]
296
+ return {new_field_name: field_list}
data_measurements/embeddings.py ADDED
@@ -0,0 +1,550 @@
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from os.path import exists
17
+ from os.path import join as pjoin
18
+
19
+ import plotly.graph_objects as go
20
+ import torch
21
+ import transformers
22
+ from datasets import load_from_disk
23
+ from plotly.io import read_json
24
+ from tqdm import tqdm
25
+
26
+ from .dataset_utils import EMBEDDING_FIELD
27
+
28
+
29
+ def sentence_mean_pooling(model_output, attention_mask):
30
+ """Mean pooling of token embeddings for a sentence."""
31
+ token_embeddings = model_output[
32
+ 0
33
+ ] # First element of model_output contains all token embeddings
34
+ input_mask_expanded = (
35
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
36
+ )
37
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
38
+ input_mask_expanded.sum(1), min=1e-9
39
+ )
40
+
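+ # Shape sketch for the pooling above: token embeddings (N, T, D) and
+ # attention_mask (N, T) -> pooled sentence embeddings (N, D); padding tokens
+ # are zeroed out by the mask and the clamp guards against division by zero.
+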
41
+
42
+ class Embeddings:
43
+ def __init__(
44
+ self,
45
+ dstats=None,
46
+ text_dset=None,
47
+ text_field_name="text",
48
+ cache_path="",
49
+ use_cache=False,
50
+ ):
51
+ """Item embeddings and clustering"""
52
+ self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
53
+ self.model_name = "sentence-transformers/all-mpnet-base-v2"
54
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
55
+ self.model = transformers.AutoModel.from_pretrained(self.model_name).to(
56
+ self.device
57
+ )
58
+ self.text_dset = text_dset if dstats is None else dstats.text_dset
59
+ self.text_field_name = (
60
+ text_field_name if dstats is None else dstats.our_text_field
61
+ )
62
+ self.cache_path = cache_path if dstats is None else dstats.cache_path
63
+ self.embeddings_dset_fid = pjoin(self.cache_path, "embeddings_dset")
64
+ self.embeddings_dset = None
65
+ self.node_list_fid = pjoin(self.cache_path, "node_list.th")
66
+ self.node_list = None
67
+ self.nid_map = None
68
+ self.fig_tree_fid = pjoin(self.cache_path, "node_figure.json")
69
+ self.fig_tree = None
70
+ self.cached_clusters = {}
71
+ self.use_cache = use_cache
72
+
73
+ def compute_sentence_embeddings(self, sentences):
74
+ """
75
+ Takes a list of sentences and computes their embeddings
76
+ using self.tokenizer and self.model (with output dimension D)
77
+ followed by mean pooling of the token representations and normalization
78
+ Args:
79
+ sentences ([string]): list of N input sentences
80
+ Returns:
81
+ torch.Tensor: sentence embeddings, dimension NxD
82
+ """
83
+ batch = self.tokenizer(
84
+ sentences, padding=True, truncation=True, return_tensors="pt"
85
+ )
86
+ batch = {k: v.to(self.device) for k, v in batch.items()}
87
+ with torch.no_grad():
88
+ model_output = self.model(**batch)
89
+ sentence_embeds = sentence_mean_pooling(
90
+ model_output, batch["attention_mask"]
91
+ )
92
+ sentence_embeds /= sentence_embeds.norm(dim=-1, keepdim=True)
93
+ return sentence_embeds
94
+
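+ # Usage sketch (illustrative; `my_text_dset` is a placeholder HF dataset with
+ # a "text" column, and the call downloads the sentence-transformers model):
+ #
+ #     emb = Embeddings(text_dset=my_text_dset, text_field_name="text")
+ #     vecs = emb.compute_sentence_embeddings(["a cat", "a dog"])
+ #     sim = float((vecs[0] * vecs[1]).sum())  # cosine similarity: rows are unit-norm
+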
95
+ def make_embeddings(self):
96
+ """
97
+ Batch computes the embeddings of the Dataset self.text_dset,
98
+ using the field self.text_field_name as input.
99
+ Returns:
100
+ Dataset: HF dataset object with a single EMBEDDING_FIELD field
101
+ corresponding to the embeddings (list of floats)
102
+ """
103
+
104
+ def batch_embed_sentences(sentences):
105
+ return {
106
+ EMBEDDING_FIELD: [
107
+ embed.tolist()
108
+ for embed in self.compute_sentence_embeddings(
109
+ sentences[self.text_field_name]
110
+ )
111
+ ]
112
+ }
113
+
114
+ self.embeddings_dset = self.text_dset.map(
115
+ batch_embed_sentences,
116
+ batched=True,
117
+ batch_size=32,
118
+ remove_columns=[self.text_field_name],
119
+ )
120
+
121
+ return self.embeddings_dset
122
+
123
+ def make_text_embeddings(self):
124
+ """Load embeddings dataset from cache or compute it."""
125
+ if self.use_cache and exists(self.embeddings_dset_fid):
126
+ self.embeddings_dset = load_from_disk(self.embeddings_dset_fid)
127
+ else:
128
+ self.embeddings_dset = self.make_embeddings()
129
+ self.embeddings_dset.save_to_disk(self.embeddings_dset_fid)
130
+
131
+ def make_hierarchical_clustering(
132
+ self,
133
+ batch_size=1000,
134
+ approx_neighbors=1000,
135
+ min_cluster_size=10,
136
+ ):
137
+ if self.use_cache and exists(self.node_list_fid):
138
+ self.node_list, self.nid_map = torch.load(self.node_list_fid)
139
+ else:
140
+ self.make_text_embeddings()
141
+ embeddings = torch.Tensor(self.embeddings_dset[EMBEDDING_FIELD])
142
+ self.node_list = fast_cluster(
143
+ embeddings, batch_size, approx_neighbors, min_cluster_size
144
+ )
145
+ self.nid_map = dict(
146
+ [(node["nid"], nid) for nid, node in enumerate(self.node_list)]
147
+ )
148
+ torch.save((self.node_list, self.nid_map), self.node_list_fid)
149
+ print(exists(self.fig_tree_fid), self.fig_tree_fid)
150
+ if self.use_cache and exists(self.fig_tree_fid):
151
+ self.fig_tree = read_json(self.fig_tree_fid)
152
+ else:
153
+ self.fig_tree = make_tree_plot(
154
+ self.node_list, self.nid_map, self.text_dset, self.text_field_name
155
+ )
156
+ self.fig_tree.write_json(self.fig_tree_fid)
157
+
158
+ def find_cluster_beam(self, sentence, beam_size=20):
159
+ """
160
+ This function finds the `beam_size` leaf clusters that are closest to the
161
+ proposed sentence and returns the full path from the root to the cluster
162
+ along with the dot product between the sentence embedding and the
163
+ cluster centroid
164
+ Args:
165
+ sentence (string): input sentence for which to find clusters
166
+ beam_size (int): this is a beam size algorithm to explore the tree
167
+ Returns:
168
+ [([int], float)]: list of (path_from_root, score) sorted by score
169
+ """
170
+ embed = self.compute_sentence_embeddings([sentence])[0].to("cpu")
171
+ active_paths = [([0], torch.dot(embed, self.node_list[0]["centroid"]).item())]
172
+ finished_paths = []
173
+ children_ids_list = [
174
+ [
175
+ self.nid_map[nid]
176
+ for nid in self.node_list[path[-1]]["children_ids"]
177
+ if nid in self.nid_map
178
+ ]
179
+ for path, score in active_paths
180
+ ]
181
+ while len(active_paths) > 0:
182
+ next_ids = sorted(
183
+ [
184
+ (
185
+ beam_id,
186
+ nid,
187
+ torch.dot(embed, self.node_list[nid]["centroid"]).item(),
188
+ )
189
+ for beam_id, children_ids in enumerate(children_ids_list)
190
+ for nid in children_ids
191
+ ],
192
+ key=lambda x: x[2],
193
+ reverse=True,
194
+ )[:beam_size]
195
+ paths = [
196
+ (active_paths[beam_id][0] + [next_id], score)
197
+ for beam_id, next_id, score in next_ids
198
+ ]
199
+ active_paths = []
200
+ for path, score in paths:
201
+ if (
202
+ len(
203
+ [
204
+ nid
205
+ for nid in self.node_list[path[-1]]["children_ids"]
206
+ if nid in self.nid_map
207
+ ]
208
+ )
209
+ > 0
210
+ ):
211
+ active_paths += [(path, score)]
212
+ else:
213
+ finished_paths += [(path, score)]
214
+ children_ids_list = [
215
+ [
216
+ self.nid_map[nid]
217
+ for nid in self.node_list[path[-1]]["children_ids"]
218
+ if nid in self.nid_map
219
+ ]
220
+ for path, score in active_paths
221
+ ]
222
+ return sorted(
223
+ finished_paths,
224
+ key=lambda x: x[-1],
225
+ reverse=True,
226
+ )[:beam_size]
227
+
228
+
229
+ def prepare_merges(embeddings, batch_size=1000, approx_neighbors=1000, low_thres=0.5):
230
+ """
231
+ Prepares an initial list of merges for hierarchical
232
+ clustering. First compute the `approx_neighbors` nearest neighbors,
233
+ then propose a merge for any two points that are closer than `low_thres`
234
+
235
+ Note that if a point has more than `approx_neighbors` neighbors
236
+ closer than `low_thres`, this approach will miss some of those merges
237
+
238
+ Args:
239
+ embeddings (torch.Tensor): Tensor of sentence embeddings - dimension NxD
240
+ batch_size (int): compute nearest neighbors of `batch_size` points at a time
241
+ approx_neighbors (int): only keep `approx_neighbors` nearest neighbors of a point
242
+ low_thres (float): only return merges where the dot product is greater than `low_thres`
243
+ Returns:
244
+ torch.LongTensor: proposed merges ([i, j] with i>j) - dimension: Mx2
245
+ torch.Tensor: merge scores - dimension M
246
+ """
247
+ top_idx_pre = torch.cat(
248
+ [torch.LongTensor(range(embeddings.shape[0]))[:, None]] * batch_size, dim=1
249
+ )
250
+ top_val_all = torch.Tensor(0, approx_neighbors)
251
+ top_idx_all = torch.LongTensor(0, approx_neighbors)
252
+ n_batches = math.ceil(len(embeddings) / batch_size)
253
+ for b in tqdm(range(n_batches)):
254
+ # TODO: batch across second dimension
255
+ cos_scores = torch.mm(
256
+ embeddings[b * batch_size : (b + 1) * batch_size], embeddings.t()
257
+ )
258
+ for i in range(cos_scores.shape[0]):
259
+ cos_scores[i, (b * batch_size) + i :] = -1
260
+ top_val_large, top_idx_large = cos_scores.topk(
261
+ k=approx_neighbors, dim=-1, largest=True
262
+ )
263
+ top_val_all = torch.cat([top_val_all, top_val_large], dim=0)
264
+ top_idx_all = torch.cat([top_idx_all, top_idx_large], dim=0)
265
+ max_neighbor_dist = top_val_large[:, -1].max().item()
266
+ if max_neighbor_dist > low_thres:
267
+ print(
268
+ f"WARNING: with the current set of neireast neighbor, the farthest is {max_neighbor_dist}"
269
+ )
270
+
271
+ all_merges = torch.cat(
272
+ [
273
+ top_idx_pre[top_val_all > low_thres][:, None],
274
+ top_idx_all[top_val_all > low_thres][:, None],
275
+ ],
276
+ dim=1,
277
+ )
278
+ all_merge_scores = top_val_all[top_val_all > low_thres]
279
+
280
+ return (all_merges, all_merge_scores)
281
+
282
+
283
+ def merge_nodes(nodes, current_thres, previous_thres, all_merges, all_merge_scores):
284
+ """
285
+ Merge all nodes if the max dot product between any of their descendants
286
+ is greater than current_thres.
287
+
288
+ Args:
289
+ nodes ([dict]): list of dicts representing the current set of nodes
290
+ current_thres (float): merge all nodes closer than current_thres
291
+ previous_thres (float): nodes closer than previous_thres are already merged
292
+ all_merges (torch.LongTensor): proposed merges ([i, j] with i>j) - dimension: Mx2
293
+ all_merge_scores (torch.Tensor): merge scores - dimension M
294
+ Returns:
295
+ [dict]: extended list with the newly created internal nodes
296
+ """
297
+ merge_ids = (all_merge_scores <= previous_thres) * (
298
+ all_merge_scores > current_thres
299
+ )
300
+ if merge_ids.sum().item() > 0:
301
+ merges = all_merges[merge_ids]
302
+ for a, b in merges.tolist():
303
+ node_a = nodes[a]
304
+ while node_a["parent_id"] != -1:
305
+ node_a = nodes[node_a["parent_id"]]
306
+ node_b = nodes[b]
307
+ while node_b["parent_id"] != -1:
308
+ node_b = nodes[node_b["parent_id"]]
309
+ if node_a["nid"] == node_b["nid"]:
310
+ continue
311
+ else:
312
+ # merge if threshold allows
313
+ if (node_a["depth"] + node_b["depth"]) > 0 and min(
314
+ node_a["merge_threshold"], node_b["merge_threshold"]
315
+ ) == current_thres:
316
+ merge_to = None
317
+ merge_from = None
318
+ if node_a["nid"] < node_b["nid"]:
319
+ merge_from = node_a
320
+ merge_to = node_b
321
+ if node_a["nid"] > node_b["nid"]:
322
+ merge_from = node_b
323
+ merge_to = node_a
324
+ merge_to["depth"] = max(merge_to["depth"], merge_from["depth"])
325
+ merge_to["weight"] += merge_from["weight"]
326
+ merge_to["children_ids"] += (
327
+ merge_from["children_ids"]
328
+ if merge_from["depth"] > 0
329
+ else [merge_from["nid"]]
330
+ )
331
+ for cid in merge_from["children_ids"]:
332
+ nodes[cid]["parent_id"] = merge_to["nid"]
333
+ merge_from["parent_id"] = merge_to["nid"]
334
+ # else new node
335
+ else:
336
+ new_nid = len(nodes)
337
+ new_node = {
338
+ "nid": new_nid,
339
+ "parent_id": -1,
340
+ "depth": max(node_a["depth"], node_b["depth"]) + 1,
341
+ "weight": node_a["weight"] + node_b["weight"],
342
+ "children": [],
343
+ "children_ids": [node_a["nid"], node_b["nid"]],
344
+ "example_ids": [],
345
+ "merge_threshold": current_thres,
346
+ }
347
+ node_a["parent_id"] = new_nid
348
+ node_b["parent_id"] = new_nid
349
+ nodes += [new_node]
350
+ return nodes
351
+
352
+
353
+ def finalize_node(node, nodes, min_cluster_size):
354
+ """Post-process nodes to sort children by descending weight,
355
+ get the full list of leaves in the sub-tree, and direct links
+ to the children nodes, then recurse over all children.
357
+
358
+ Nodes with fewer than `min_cluster_size` descendants are collapsed
359
+ into a single leaf.
360
+ """
361
+ node["children"] = sorted(
362
+ [
363
+ finalize_node(nodes[cid], nodes, min_cluster_size)
364
+ for cid in node["children_ids"]
365
+ ],
366
+ key=lambda x: x["weight"],
367
+ reverse=True,
368
+ )
369
+ if node["depth"] > 0:
370
+ node["example_ids"] = [
371
+ eid for child in node["children"] for eid in child["example_ids"]
372
+ ]
373
+ node["children"] = [
374
+ child for child in node["children"] if child["weight"] >= min_cluster_size
375
+ ]
376
+ assert node["weight"] == len(node["example_ids"]), print(node)
377
+ return node
378
+
379
+
380
+ def fast_cluster(
381
+ embeddings,
382
+ batch_size=1000,
383
+ approx_neighbors=1000,
384
+ min_cluster_size=10,
385
+ low_thres=0.5,
386
+ ):
387
+ """
388
+ Computes an approximate hierarchical clustering based on example
389
+ embeddings. The join criterion is min clustering, i.e. two clusters
390
+ are joined if any pair of their descendants are closer than a threshold
391
+
392
+ The approximate comes from the fact that only the `approx_neighbors` nearest
393
+ neighbors of an example are considered for merges
394
+ """
395
+ batch_size = min(embeddings.shape[0], batch_size)
396
+ all_merges, all_merge_scores = prepare_merges(
397
+ embeddings, batch_size, approx_neighbors, low_thres
398
+ )
399
+ # prepare leaves
400
+ nodes = [
401
+ {
402
+ "nid": nid,
403
+ "parent_id": -1,
404
+ "depth": 0,
405
+ "weight": 1,
406
+ "children": [],
407
+ "children_ids": [],
408
+ "example_ids": [nid],
409
+ "merge_threshold": 1.0,
410
+ }
411
+ for nid in range(embeddings.shape[0])
412
+ ]
413
+ # one level per threshold range
414
+ for i in range(10):
415
+ p_thres = 1 - i * 0.05
416
+ c_thres = 0.95 - i * 0.05
417
+ nodes = merge_nodes(nodes, c_thres, p_thres, all_merges, all_merge_scores)
418
+ # make root
419
+ root_children = [
420
+ node
421
+ for node in nodes
422
+ if node["parent_id"] == -1 and node["weight"] >= min_cluster_size
423
+ ]
424
+ root = {
425
+ "nid": len(nodes),
426
+ "parent_id": -1,
427
+ "depth": max([node["depth"] for node in root_children]) + 1,
428
+ "weight": sum([node["weight"] for node in root_children]),
429
+ "children": [],
430
+ "children_ids": [node["nid"] for node in root_children],
431
+ "example_ids": [],
432
+ "merge_threshold": -1.0,
433
+ }
434
+ nodes += [root]
435
+ for node in root_children:
436
+ node["parent_id"] = root["nid"]
437
+ # finalize tree
438
+ tree = finalize_node(root, nodes, min_cluster_size)
439
+ node_list = []
440
+
441
+ def rec_map_nodes(node, node_list):
442
+ node_list += [node]
443
+ for child in node["children"]:
444
+ rec_map_nodes(child, node_list)
445
+
446
+ rec_map_nodes(tree, node_list)
447
+ # get centroids and distances
448
+ for node in node_list:
449
+ node_embeds = embeddings[node["example_ids"]]
450
+ node["centroid"] = node_embeds.sum(dim=0)
451
+ node["centroid"] /= node["centroid"].norm()
452
+ node["centroid_dot_prods"] = torch.mv(node_embeds, node["centroid"])
453
+ node["sorted_examples_centroid"] = sorted(
454
+ [
455
+ (eid, edp.item())
456
+ for eid, edp in zip(node["example_ids"], node["centroid_dot_prods"])
457
+ ],
458
+ key=lambda x: x[1],
459
+ reverse=True,
460
+ )
461
+ return node_list
462
+
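+ # Smoke-test sketch for fast_cluster above (illustrative; random unit vectors,
+ # so the exact tree shape is not meaningful):
+ #
+ #     embs = torch.nn.functional.normalize(torch.randn(200, 8), dim=-1)
+ #     nodes = fast_cluster(embs, batch_size=50, approx_neighbors=50,
+ #                          min_cluster_size=5)
+ #     nodes[0]["weight"]  # number of examples under the root
+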
463
+
464
+ def make_tree_plot(node_list, nid_map, text_dset, text_field_name):
465
+ """
466
+ Makes a graphical representation of the tree encoded
467
+ in node-list. The hover label for each node shows the number
468
+ of descendants and the 5 examples that are closest to the centroid
469
+ """
470
+ for nid, node in enumerate(node_list):
471
+ # get the 5 examples closest to the node centroid for the hover label
472
+ node_examples = {}
473
+ for sid, score in node["sorted_examples_centroid"]:
474
+ node_examples[text_dset[sid][text_field_name]] = score
475
+ if len(node_examples) >= 5:
476
+ break
477
+ node["label"] = node.get(
478
+ "label",
479
+ f"{nid:2d} - {node['weight']:5d} items <br>"
480
+ + "<br>".join(
481
+ [
482
+ f" {score:.2f} > {txt[:64]}" + ("..." if len(txt) >= 63 else "")
483
+ for txt, score in node_examples.items()
484
+ ]
485
+ ),
486
+ )
487
+
488
+ # make plot nodes
489
+ labels = [node["label"] for node in node_list]
490
+
491
+ root = node_list[0]
492
+ root["X"] = 0
493
+ root["Y"] = 0
494
+
495
+ def rec_make_coordinates(node):
496
+ total_weight = 0
497
+ add_weight = len(node["example_ids"]) - sum(
498
+ [child["weight"] for child in node["children"]]
499
+ )
500
+ for child in node["children"]:
501
+ child["X"] = node["X"] + total_weight
502
+ child["Y"] = node["Y"] - 1
503
+ total_weight += child["weight"] + add_weight / len(node["children"])
504
+ rec_make_coordinates(child)
505
+
506
+ rec_make_coordinates(root)
507
+
508
+ E = [] # list of edges
509
+ Xn = []
510
+ Yn = []
511
+ Xe = []
512
+ Ye = []
513
+ for nid, node in enumerate(node_list):
514
+ Xn += [node["X"]]
515
+ Yn += [node["Y"]]
516
+ for child in node["children"]:
517
+ E += [(nid, nid_map[child["nid"]])]
518
+ Xe += [node["X"], child["X"], None]
519
+ Ye += [node["Y"], child["Y"], None]
520
+
521
+ # make figure
522
+ fig = go.Figure()
523
+ fig.add_trace(
524
+ go.Scatter(
525
+ x=Xe,
526
+ y=Ye,
527
+ mode="lines",
528
+ line=dict(color="rgb(210,210,210)", width=1),
529
+ hoverinfo="none",
530
+ )
531
+ )
532
+ fig.add_trace(
533
+ go.Scatter(
534
+ x=Xn,
535
+ y=Yn,
536
+ mode="markers",
537
+ name="nodes",
538
+ marker=dict(
539
+ symbol="circle-dot",
540
+ size=18,
541
+ color="#6175c1",
542
+ line=dict(color="rgb(50,50,50)", width=1)
543
+ # '#DB4551',
544
+ ),
545
+ text=labels,
546
+ hoverinfo="text",
547
+ opacity=0.8,
548
+ )
549
+ )
550
+ return fig
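
For orientation, here is a minimal sketch (not part of the commit) of how the two helpers above fit together. It assumes `node_list` and `embeddings` come from the clustering step earlier in this file, that `text_dset` is the loaded text dataset, and that `nid_map` maps each node's `nid` to its position in `node_list`, which is how `make_tree_plot` indexes into it:

    # Hypothetical wiring of the helpers above; every name other than
    # make_tree_plot is an assumption based on the surrounding code.
    nid_map = {node["nid"]: pos for pos, node in enumerate(node_list)}
    fig = make_tree_plot(node_list, nid_map, text_dset, "text")
    fig.show()  # or st.plotly_chart(fig, use_container_width=True) in the app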
data_measurements/npmi.py ADDED
@@ -0,0 +1,254 @@
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import warnings
+ from pathlib import Path
+
+ import numpy as np
+ import pandas as pd
+ from sklearn.preprocessing import MultiLabelBinarizer
+
+ # Might be nice to print to log instead? Happens when we drop closed class.
+ warnings.filterwarnings(action="ignore", category=UserWarning)
+ # When we divide by 0 in log
+ np.seterr(divide="ignore")
+
+ # treating inf values as NaN as well
+ pd.set_option("use_inf_as_na", True)
+
+ logs = logging.getLogger(__name__)
+ logs.setLevel(logging.INFO)
+ logs.propagate = False
+
+ if not logs.handlers:
+     Path("./log_files").mkdir(exist_ok=True)
+
+     # Logging info to log file
+     file = logging.FileHandler("./log_files/npmi.log")
+     fileformat = logging.Formatter("%(asctime)s:%(message)s")
+     file.setLevel(logging.INFO)
+     file.setFormatter(fileformat)
+
+     # Logging warnings (and above) to stream
+     stream = logging.StreamHandler()
+     streamformat = logging.Formatter("[data_measurements_tool] %(message)s")
+     stream.setLevel(logging.WARNING)
+     stream.setFormatter(streamformat)
+
+     logs.addHandler(file)
+     logs.addHandler(stream)
+
+ _NUM_BATCHES = 500
+
+
+ class nPMI:
+     # TODO: Expand beyond pairwise
+     def __init__(
+         self,
+         vocab_counts_df,
+         tokenized_df,
+         tokenized_col_name="tokenized_text",
+         num_batches=_NUM_BATCHES,
+     ):
+         logs.info("Initializing nPMI class.")
+         logs.info("vocab is")
+         logs.info(vocab_counts_df)
+         self.vocab_counts_df = vocab_counts_df
+         logs.info("tokenized is")
+         self.tokenized_df = tokenized_df
+         logs.info(self.tokenized_df)
+         self.tokenized_col_name = tokenized_col_name
+         self.num_batches = num_batches
+         # self.mlb_list holds num batches x num_sentences
+         self.mlb_list = []
+
+     def binarize_words_in_sentence(self):
+         logs.info("Creating co-occurrence matrix for PMI calculations.")
+         batches = np.linspace(0, self.tokenized_df.shape[0], self.num_batches).astype(int)
+         i = 0
+         # Creates list of size (# batches x # sentences)
+         while i < len(batches) - 1:
+             # Makes a sparse matrix (shape: # sentences x # words),
+             # with the occurrence of each word per sentence.
+             mlb = MultiLabelBinarizer(classes=self.vocab_counts_df.index)
+             logs.info(
+                 "%s of %s sentence binarize batches." % (str(i), str(len(batches)))
+             )
+             # Returns series: batch size x num_words
+             mlb_series = mlb.fit_transform(
+                 self.tokenized_df[self.tokenized_col_name][batches[i] : batches[i + 1]]
+             )
+             i += 1
+             self.mlb_list.append(mlb_series)
+
+     def calc_cooccurrences(self, subgroup, subgroup_idx):
+         initialize = True
+         coo_df = None
+         # Big computation here! Should only happen once.
+         logs.info(
+             "Approaching big computation! Here, we binarize all words in the sentences, making a sparse matrix of sentences."
+         )
+         if not self.mlb_list:
+             self.binarize_words_in_sentence()
+         for batch_id in range(len(self.mlb_list)):
+             logs.info(
+                 "%s of %s co-occurrence count batches"
+                 % (str(batch_id), str(len(self.mlb_list)))
+             )
+             # List of all the sentences (list of vocab) in that batch
+             batch_sentence_row = self.mlb_list[batch_id]
+             # Dataframe of # sentences in batch x vocabulary size
+             sent_batch_df = pd.DataFrame(batch_sentence_row)
+             # logs.info('sent batch df is')
+             # logs.info(sent_batch_df)
+             # Subgroup counts per-sentence for the given batch
+             subgroup_df = sent_batch_df[subgroup_idx]
+             subgroup_df.columns = [subgroup]
+             # Remove the sentences where the count of the subgroup is 0.
+             # This way we need less computation and fewer resources.
+             subgroup_df = subgroup_df[subgroup_df > 0]
+             logs.info("Removing 0 counts, subgroup_df is")
+             logs.info(subgroup_df)
+             mlb_subgroup_only = sent_batch_df[sent_batch_df[subgroup_idx] > 0]
+             logs.info("mlb subgroup only is")
+             logs.info(mlb_subgroup_only)
+             # Create the co-occurrence matrix for the given subgroup and all words.
+             logs.info("Now we do the T.dot approach for co-occurrences")
+             batch_coo_df = pd.DataFrame(mlb_subgroup_only.T.dot(subgroup_df))
+
+             # Creates a batch-sized dataframe of co-occurrence counts.
+             # Note these could just be summed rather than be batch size.
+             if initialize:
+                 coo_df = batch_coo_df
+             else:
+                 coo_df = coo_df.add(batch_coo_df, fill_value=0)
+             logs.info("coo_df is")
+             logs.info(coo_df)
+             initialize = False
+         logs.info("Returning co-occurrence matrix")
+         logs.info(coo_df)
+         return pd.DataFrame(coo_df)
+
+     def calc_paired_metrics(self, subgroup_pair, subgroup_npmi_dict):
+         """
+         Calculates nPMI metrics between paired subgroups.
+         Special handling for a subgroup paired with itself.
+         :param subgroup_pair: Pair of subgroup terms, e.g. ("her", "him")
+         :param subgroup_npmi_dict: Dict mapping each subgroup to its
+             (vocab co-occurrence df, PMI df, nPMI df) tuple.
+         :return: Dict of the joined count, PMI, nPMI, and nPMI-bias dataframes.
+         """
+         paired_results_dict = {"npmi": {}, "pmi": {}, "count": {}}
+         # Canonical ordering. This is done previously, but just in case...
+         subgroup1, subgroup2 = sorted(subgroup_pair)
+         vocab_cooc_df1, pmi_df1, npmi_df1 = subgroup_npmi_dict[subgroup1]
+         logs.info("vocab cooc")
+         logs.info(vocab_cooc_df1)
+         if subgroup1 == subgroup2:
+             shared_npmi_df = npmi_df1
+             shared_pmi_df = pmi_df1
+             shared_vocab_cooc_df = vocab_cooc_df1
+         else:
+             vocab_cooc_df2, pmi_df2, npmi_df2 = subgroup_npmi_dict[subgroup2]
+             logs.info("vocab cooc2")
+             logs.info(vocab_cooc_df2)
+             # Note that lsuffix and rsuffix should not come into play.
+             shared_npmi_df = npmi_df1.join(
+                 npmi_df2, how="inner", lsuffix="1", rsuffix="2"
+             )
+             shared_pmi_df = pmi_df1.join(pmi_df2, how="inner", lsuffix="1", rsuffix="2")
+             shared_vocab_cooc_df = vocab_cooc_df1.join(
+                 vocab_cooc_df2, how="inner", lsuffix="1", rsuffix="2"
+             )
+             shared_vocab_cooc_df = shared_vocab_cooc_df.dropna()
+             shared_vocab_cooc_df = shared_vocab_cooc_df[
+                 shared_vocab_cooc_df.index.notnull()
+             ]
+             logs.info("shared npmi df")
+             logs.info(shared_npmi_df)
+             logs.info("shared vocab df")
+             logs.info(shared_vocab_cooc_df)
+         npmi_bias = (
+             shared_npmi_df[subgroup1 + "-npmi"] - shared_npmi_df[subgroup2 + "-npmi"]
+         )
+         paired_results_dict["npmi-bias"] = npmi_bias.dropna()
+         paired_results_dict["npmi"] = shared_npmi_df.dropna()
+         paired_results_dict["pmi"] = shared_pmi_df.dropna()
+         paired_results_dict["count"] = shared_vocab_cooc_df.dropna()
+         return paired_results_dict
+
+     def calc_metrics(self, subgroup):
+         # Index of the subgroup word in the sparse vector
+         subgroup_idx = self.vocab_counts_df.index.get_loc(subgroup)
+         logs.info("Calculating co-occurrences...")
+         df_coo = self.calc_cooccurrences(subgroup, subgroup_idx)
+         vocab_cooc_df = self.set_idx_cols(df_coo, subgroup)
+         logs.info(vocab_cooc_df)
+         logs.info("Calculating PMI...")
+         pmi_df = self.calc_PMI(vocab_cooc_df, subgroup)
+         logs.info(pmi_df)
+         logs.info("Calculating nPMI...")
+         npmi_df = self.calc_nPMI(pmi_df, vocab_cooc_df, subgroup)
+         logs.info(npmi_df)
+         return vocab_cooc_df, pmi_df, npmi_df
+
+     def set_idx_cols(self, df_coo, subgroup):
+         """
+         :param df_coo: Co-occurrence counts for the subgroup; length is num_words
+         :return: Counts dataframe indexed by the vocabulary
+         """
+         count_df = df_coo.set_index(self.vocab_counts_df.index)
+         count_df.columns = [subgroup + "-count"]
+         count_df[subgroup + "-count"] = count_df[subgroup + "-count"].astype(int)
+         return count_df
+
+     def calc_PMI(self, vocab_cooc_df, subgroup):
+         """
+         PMI(x;y) = h(y) - h(y|x)
+                  = h(subgroup) - h(subgroup|word)
+                  = log (p(subgroup|word) / p(subgroup))
+         nPMI additionally divides by -log(p(x,y)) = -log(p(x|y)p(y))
+         """
+         # Calculation of p(subgroup)
+         subgroup_prob = self.vocab_counts_df.loc[subgroup]["proportion"]
+         # Calculation of p(subgroup|word) = count(subgroup,word) / count(word)
+         # Because the indices match (the vocab words),
+         # this division aligns automatically on the index.
+         p_subgroup_g_word = (
+             vocab_cooc_df[subgroup + "-count"] / self.vocab_counts_df["count"]
+         )
+         logs.info("p_subgroup_g_word is")
+         logs.info(p_subgroup_g_word)
+         pmi_df = pd.DataFrame()
+         pmi_df[subgroup + "-pmi"] = np.log(p_subgroup_g_word / subgroup_prob)
+         # Note: A potentially faster solution for adding count, npmi,
+         # can be based on this zip idea:
+         # df_test['size_kb'], df_test['size_mb'], df_test['size_gb'] =
+         #     zip(*df_test['size'].apply(sizes))
+         return pmi_df.dropna()
+
+     def calc_nPMI(self, pmi_df, vocab_cooc_df, subgroup):
+         """
+         nPMI additionally divides the PMI by -log(p(x,y)) = -log(p(x|y)p(y))
+                                                           = -log(p(word|subgroup)p(word))
+         """
+         p_word_g_subgroup = vocab_cooc_df[subgroup + "-count"] / sum(
+             vocab_cooc_df[subgroup + "-count"]
+         )
+         p_word = pmi_df.apply(
+             lambda x: self.vocab_counts_df.loc[x.name]["proportion"], axis=1
+         )
+         normalize_pmi = -np.log(p_word_g_subgroup * p_word)
+         npmi_df = pd.DataFrame()
+         npmi_df[subgroup + "-npmi"] = pmi_df[subgroup + "-pmi"] / normalize_pmi
+         return npmi_df.dropna()
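
To make the expected inputs concrete, here is a minimal sketch (not from the commit) of driving the `nPMI` class directly. The column names follow what `calc_PMI` reads: a vocabulary-indexed dataframe with `count` and `proportion` columns, plus a tokenized dataframe whose `tokenized_text` column holds lists of tokens. The toy data is made up purely for illustration:

    import pandas as pd
    from data_measurements.npmi import nPMI

    # Two toy "sentences" as token lists, in the tokenized_text column.
    tokenized_df = pd.DataFrame(
        {"tokenized_text": [["she", "is", "a", "doctor"], ["he", "is", "a", "nurse"]]}
    )
    # Vocabulary counts with the "count" and "proportion" columns calc_PMI expects.
    counts = pd.Series(
        {"she": 1, "he": 1, "is": 2, "a": 2, "doctor": 1, "nurse": 1}, name="count"
    )
    vocab_counts_df = pd.DataFrame(
        {"count": counts, "proportion": counts / counts.sum()}
    )

    npmi_obj = nPMI(vocab_counts_df, tokenized_df)
    vocab_cooc_df, pmi_df, npmi_df = npmi_obj.calc_metrics("she")
    print(npmi_df)  # per-word nPMI scores for the "she" subgroup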
data_measurements/streamlit_utils.py ADDED
@@ -0,0 +1,498 @@
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import statistics
+
+ import pandas as pd
+ import seaborn as sns
+ import streamlit as st
+ # from st_aggrid import AgGrid, GridOptionsBuilder
+
+ from .dataset_utils import HF_DESC_FIELD, HF_FEATURE_FIELD, HF_LABEL_FIELD
+
+ st.set_option('deprecation.showPyplotGlobalUse', False)
+ json_file_path = "cache_dir/has_cache.json"
+ with open(json_file_path, "r", encoding="utf-8") as j:
+     _HAS_CACHE = json.loads(j.read())
+
+
+ def sidebar_header():
+     st.sidebar.markdown(
+         """
+         This demo showcases the [dataset measures as we develop them](https://huggingface.co/blog/data-measurements-tool).
+         Right now this has a few pre-loaded datasets for which you can:
+         - view some general statistics about the text vocabulary, lengths, labels
+         - explore some distributional statistics to assess properties of the language
+         - view some comparison statistics and an overview of the text distribution
+
+         The tool is in development, and will keep growing in utility and functionality 🤗🚧
+         """,
+         unsafe_allow_html=True,
+     )
+
+
+ def sidebar_selection(ds_name_to_dict, column_id):
+     # ds_names = list(ds_name_to_dict.keys())
+     ds_names = list(_HAS_CACHE.keys())
+     with st.sidebar.expander(f"Choose dataset and field {column_id}", expanded=True):
+         # choose a dataset to analyze
+         ds_name = st.selectbox(
+             f"Choose dataset to explore{column_id}:",
+             ds_names,
+             index=ds_names.index("hate_speech18"),
+         )
+         # choose a config to analyze
+         ds_configs = ds_name_to_dict[ds_name]
+         if ds_name == "c4":
+             config_names = ['en', 'en.noblocklist', 'realnewslike']
+         else:
+             config_names = list(ds_configs.keys())
+         # Only offer configurations that have already been cached.
+         config_names = list(_HAS_CACHE[ds_name].keys())
+         config_name = st.selectbox(
+             f"Choose configuration{column_id}:",
+             config_names,
+             index=0,
+         )
+         # choose a subset of num_examples
+         # TODO: Handling for multiple text features
+         # ds_config = ds_configs[config_name]
+         # text_features = ds_config[HF_FEATURE_FIELD]["string"]
+         text_features = [
+             tuple(text_field.split('-'))
+             for text_field in _HAS_CACHE[ds_name][config_name]
+         ]
+         # TODO @yacine: Explain what this is doing and why e.g. tp[0] could = "id"
+         text_field = st.selectbox(
+             f"Which text feature from the{column_id} dataset would you like to analyze?",
+             [("text",)]
+             if ds_name == "c4"
+             else [tp for tp in text_features if tp[0] != "id"],
+         )
+         # Choose a split and dataset size
+         # avail_splits = list(ds_config["splits"].keys())
+         avail_splits = list(_HAS_CACHE[ds_name][config_name]['-'.join(text_field)].keys())
+         # 12.Nov note: Removing "test" because those should not be examined
+         # without discussion of pros and cons, which we haven't done yet.
+         if "test" in avail_splits:
+             avail_splits.remove("test")
+         split = st.selectbox(
+             f"Which split from the{column_id} dataset would you like to analyze?",
+             avail_splits,
+             index=0,
+         )
+         label_field, label_names = (
+             ds_name_to_dict[ds_name][config_name][HF_FEATURE_FIELD][HF_LABEL_FIELD][0]
+             if len(
+                 ds_name_to_dict[ds_name][config_name][HF_FEATURE_FIELD][HF_LABEL_FIELD]
+             )
+             > 0
+             else ((), [])
+         )
+         return {
+             "dset_name": ds_name,
+             "dset_config": config_name,
+             "split_name": split,
+             "text_field": text_field,
+             "label_field": label_field,
+             "label_names": label_names,
+         }
+
+
+ def expander_header(dstats, ds_name_to_dict, column_id):
+     with st.expander(f"Dataset Description{column_id}"):
+         st.markdown(
+             ds_name_to_dict[dstats.dset_name][dstats.dset_config][HF_DESC_FIELD]
+         )
+         st.dataframe(dstats.dset_peek)
+
+
+ def expander_general_stats(dstats, column_id):
+     with st.expander(f"General Text Statistics{column_id}"):
+         st.caption(
+             "Use this widget to check whether the terms you see most represented"
+             " in the dataset make sense for the goals of the dataset."
+         )
+         if dstats.total_words == 0:
+             st.markdown(
+                 "Uh oh... not finding the file I need. 😭 It will probably be there soon. 🤞 Check back later!"
+             )
+         else:
+             st.markdown("There are {0} total words".format(str(dstats.total_words)))
+             st.markdown(
+                 "There are {0} words after removing closed "
+                 "class words".format(str(dstats.total_open_words))
+             )
+             st.markdown(
+                 "The most common "
+                 "[open class words](https://dictionary.apa.org/open-class-words) "
+                 "and their counts are: "
+             )
+             st.dataframe(dstats.sorted_top_vocab_df)
+             st.markdown(
+                 "There are {0} missing values in the dataset.".format(
+                     str(dstats.text_nan_count)
+                 )
+             )
+             if dstats.dedup_total > 0:
+                 st.markdown(
+                     "There are {0} duplicate items in the dataset. "
+                     "For more information about the duplicates, "
+                     "click the 'Duplicates' tab below.".format(str(dstats.dedup_total))
+                 )
+             else:
+                 st.markdown("There are 0 duplicate items in the dataset.")
+
+
+ ### Show the label distribution from the datasets
+ def expander_label_distribution(fig_labels, column_id):
+     with st.expander(f"Label Distribution{column_id}", expanded=False):
+         st.caption(
+             "Use this widget to see how balanced the labels in your dataset are."
+         )
+         if fig_labels is not None:
+             st.plotly_chart(fig_labels, use_container_width=True)
+         else:
+             st.markdown("No labels were found in the dataset.")
+
+
+ def expander_text_lengths(dstats, column_id):
+     _TEXT_LENGTH_CAPTION = (
+         "Use this widget to identify outliers, particularly suspiciously long outliers."
+     )
+     with st.expander(f"Text Lengths{column_id}", expanded=False):
+         st.caption(_TEXT_LENGTH_CAPTION)
+         st.markdown(
+             "Below, you can see how the lengths of the text instances in your dataset are distributed."
+         )
+         st.markdown(
+             "Any unexpected peaks or valleys in the distribution may help to identify instances you want to remove or augment."
+         )
+         st.markdown(
+             "### Here is the relative frequency of different text lengths in your dataset:"
+         )
+         try:
+             st.image(dstats.fig_tok_length_png)
+         except Exception:
+             st.pyplot(dstats.fig_tok_length, use_container_width=True)
+         st.markdown(
+             "The average length of text instances is **"
+             + str(dstats.avg_length)
+             + " words**, with a standard deviation of **"
+             + str(dstats.std_length)
+             + "**."
+         )
+         # This is quite a large file and was breaking our ability to navigate the app during development.
+         # Just passing if it's not already there for launch v0.
+         if dstats.length_df is not None:
+             start_id_show_lengths = st.selectbox(
+                 "Show examples of length:",
+                 sorted(dstats.length_df["length"].unique().tolist()),
+                 key=f"select_show_length_{column_id}",
+             )
+             st.table(
+                 dstats.length_df[
+                     dstats.length_df["length"] == start_id_show_lengths
+                 ].set_index("length")
+             )
+
+
+ ### Third, use a sentence embedding model
+ def expander_text_embeddings(
+     text_dset, fig_tree, node_list, embeddings, text_field, column_id
+ ):
+     with st.expander(f"Text Embedding Clusters{column_id}", expanded=False):
+         _EMBEDDINGS_CAPTION = """
+         ### Hierarchical Clustering of Text Fields
+         Taking in the diversity of text represented in a dataset can be
+         challenging when it is made up of hundreds to thousands of sentences.
+         Grouping these text items based on a measure of similarity can help
+         users gain some insights into their distribution.
+         The following figure shows a hierarchical clustering of the text fields
+         in the dataset based on a
+         [Sentence-Transformer](https://hf.co/sentence-transformers/all-mpnet-base-v2)
+         model. Clusters are merged if any of the embeddings in cluster A has a
+         dot product with any of the embeddings or with the centroid of cluster B
+         higher than a threshold (one threshold per level, from 0.5 to 0.95).
+         To explore the clusters, you can:
+         - hover over a node to see the 5 most representative examples (deduplicated)
+         - enter an example in the text box below to see which clusters it is most similar to
+         - select a cluster by ID to show all of its examples
+         """
+         st.markdown(_EMBEDDINGS_CAPTION)
+         st.plotly_chart(fig_tree, use_container_width=True)
+         st.markdown("---\n")
+         if st.checkbox(
+             label="Enter text to see nearest clusters",
+             key=f"search_clusters_{column_id}",
+         ):
+             compare_example = st.text_area(
+                 label="Enter some text here to see which of the clusters in the dataset it is closest to",
+                 key=f"search_cluster_input_{column_id}",
+             )
+             if compare_example != "":
+                 paths_to_leaves = embeddings.cached_clusters.get(
+                     compare_example,
+                     embeddings.find_cluster_beam(compare_example, beam_size=50),
+                 )
+                 clusters_intro = ""
+                 if paths_to_leaves[0][1] < 0.3:
+                     clusters_intro += (
+                         "**Warning: no close clusters found (best score < 0.3).** "
+                     )
+                 clusters_intro += "The closest clusters to the text entered above are:"
+                 st.markdown(clusters_intro)
+                 for path, score in paths_to_leaves[:5]:
+                     example = text_dset[
+                         node_list[path[-1]]["sorted_examples_centroid"][0][0]
+                     ][text_field][:256]
+                     st.write(
+                         f"Cluster {path[-1]:5d} | Score: {score:.3f} \n Example: {example}"
+                     )
+                 show_node_default = paths_to_leaves[0][0][-1]
+             else:
+                 show_node_default = len(node_list) // 2
+         else:
+             show_node_default = len(node_list) // 2
+         st.markdown("---\n")
+         if text_dset is None:
+             st.markdown("Missing source text to show; check back later!")
+         else:
+             show_node = st.selectbox(
+                 f"Choose a leaf node to explore in the{column_id} dataset:",
+                 range(len(node_list)),
+                 index=show_node_default,
+             )
+             node = node_list[show_node]
+             start_id = st.slider(
+                 f"Show closest sentences in cluster to the centroid{column_id} starting at index:",
+                 0,
+                 len(node["sorted_examples_centroid"]) - 5,
+                 value=0,
+                 step=5,
+             )
+             for sid, sim in node["sorted_examples_centroid"][start_id : start_id + 5]:
+                 # only show the first 4 lines and the first 10000 characters
+                 show_text = text_dset[sid][text_field][:10000]
+                 show_text = "\n".join(show_text.split("\n")[:4])
+                 st.text(f"{sim:.3f} \t {show_text}")
+
+
+ ### Then, show duplicates
+ def expander_text_duplicates(dstats, column_id):
+     # TODO: Saving/loading figure
+     with st.expander(f"Text Duplicates{column_id}", expanded=False):
+         st.caption(
+             "Use this widget to identify text strings that appear more than once."
+         )
+         st.markdown(
+             "A model's training and testing may be negatively affected by unwarranted duplicates ([Lee et al., 2021](https://arxiv.org/abs/2107.06499))."
+         )
+         st.markdown("------")
+         st.write(
+             "### Here is the list of all the duplicated items and their counts in your dataset:"
+         )
+         if dstats.dup_counts_df is None or dstats.dup_counts_df.empty:
+             st.write("There are no duplicates in this dataset! 🥳")
+         else:
+             st.dataframe(dstats.dup_counts_df.reset_index(drop=True))
+
+
+ def expander_npmi_description(min_vocab):
+     _NPMI_CAPTION = (
+         "Use this widget to identify problematic biases and stereotypes in your data."
+     )
+     _NPMI_CAPTION1 = """
+     nPMI scores for a word help to identify potentially
+     problematic associations, ranked by how close the association is."""
+     _NPMI_CAPTION2 = """
+     nPMI bias scores for paired words help to identify how word
+     associations are skewed between the selected words
+     ([Aka et al., 2021](https://arxiv.org/abs/2103.03417)).
+     """
+
+     st.caption(_NPMI_CAPTION)
+     st.markdown(_NPMI_CAPTION1)
+     st.markdown(_NPMI_CAPTION2)
+     st.markdown(" ")
+     st.markdown(
+         "You can select from gender and sexual orientation "
+         "identity terms that appear in the dataset at least %s "
+         "times." % min_vocab
+     )
+     st.markdown(
+         "The resulting ranked words are those that co-occur with both "
+         "identity terms."
+     )
+     st.markdown(
+         "The more *positive* the score, the more associated the word is with the first identity term. "
+         "The more *negative* the score, the more associated the word is with the second identity term."
+     )
+
+
+ ### Finally, show Zipf stuff
+ def expander_zipf(z, zipf_fig, column_id):
+     with st.expander(
+         f"Vocabulary Distribution{column_id}: Zipf's Law Fit", expanded=False
+     ):
+         try:
+             _ZIPF_CAPTION = """This shows how close the observed language is to an ideal
+             natural language distribution following [Zipf's law](https://en.wikipedia.org/wiki/Zipf%27s_law),
+             calculated by minimizing the [Kolmogorov-Smirnov (KS) statistic](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test)."""
+
+             powerlaw_eq = r"""p(x) \propto x^{- \alpha}"""
+             zipf_summary = (
+                 "The optimal alpha based on this dataset is: **"
+                 + str(round(z.alpha, 2))
+                 + "**, with a KS distance of: **"
+                 + str(round(z.distance, 2))
+             )
+             zipf_summary += (
+                 "**. This was fit with a minimum rank value of: **"
+                 + str(int(z.xmin))
+                 + "**, which is the optimal rank *beyond which* the scaling regime of the power law fits best."
+             )
+
+             alpha_warning = "Your alpha value is a bit on the high side, which means that the distribution over words in this dataset is a bit unnatural. This could be due to non-language items throughout the dataset."
+             xmin_warning = "The minimum rank for this fit is a bit on the high side, which means that the frequencies of your most common words aren't distributed as would be expected by Zipf's law."
+             fit_results_table = pd.DataFrame.from_dict(
+                 {
+                     "Alpha:": [str("%.2f" % z.alpha)],
+                     "KS distance:": [str("%.2f" % z.distance)],
+                     "Min rank:": [str("%s" % int(z.xmin))],
+                 },
+                 columns=["Results"],
+                 orient="index",
+             )
+             fit_results_table.index.name = column_id
+             st.caption(
+                 "Use this widget for the counts of different words in your dataset, measuring the difference between the observed count and the expected count under Zipf's law."
+             )
+             st.markdown(_ZIPF_CAPTION)
+             st.write(
+                 r"""
+             A Zipfian distribution follows the power law: $p(x) \propto x^{-\alpha}$
+             with an ideal α value of 1."""
+             )
+             st.markdown(
+                 "In general, an alpha greater than 2 or a minimum rank greater than 10 (take with a grain of salt) means that your distribution is relatively _unnatural_ for natural language. This can be a sign of mixed artefacts in the dataset, such as HTML markup."
+             )
+             st.markdown(
+                 "Below, you can see the counts of each word in your dataset vs. the expected number of counts following a Zipfian distribution."
+             )
+             st.markdown("-----")
+             st.write("### Here is your dataset's Zipf results:")
+             st.dataframe(fit_results_table)
+             st.write(zipf_summary)
+             # TODO: Nice UI version of the content in the comments.
+             # st.markdown("\nThe KS test p-value is < %.2f" % z.ks_test.pvalue)
+             # if z.ks_test.pvalue < 0.01:
+             #     st.markdown(
+             #         "\n Great news! Your data fits a powerlaw with a minimum KS "
+             #         "distance of %.4f" % z.distance)
+             # else:
+             #     st.markdown("\n Sadly, your data does not fit a powerlaw. =(")
+             # st.markdown("Checking the goodness of fit of our observed distribution")
+             # st.markdown("to the hypothesized power law distribution")
+             # st.markdown("using a Kolmogorov–Smirnov (KS) test.")
+             st.plotly_chart(zipf_fig, use_container_width=True)
+             if z.alpha > 2:
+                 st.markdown(alpha_warning)
+             if z.xmin > 5:
+                 st.markdown(xmin_warning)
+         except Exception:
+             st.write("Under construction! 😱 🚧")
+
+
+ ### Finally finally finally, show nPMI stuff.
+ def npmi_widget(npmi_stats, min_vocab, column_id):
+     """
+     Part of the main app, but uses a user interaction, so pulled out as its own function.
+     :param npmi_stats: nPMI statistics object for the selected dataset
+     :param min_vocab: Minimum occurrence count for an identity term to be selectable
+     :param column_id: Suffix identifying which dataset pane this widget belongs to
+     :return:
+     """
+     with st.expander(f"Word Association{column_id}: nPMI", expanded=False):
+         try:
+             if len(npmi_stats.available_terms) > 0:
+                 expander_npmi_description(min_vocab)
+                 st.markdown("-----")
+                 term1 = st.selectbox(
+                     f"What is the first term you want to select?{column_id}",
+                     npmi_stats.available_terms,
+                 )
+                 term2 = st.selectbox(
+                     f"What is the second term you want to select?{column_id}",
+                     reversed(npmi_stats.available_terms),
+                 )
+                 # We calculate/grab nPMI data based on a canonical (alphabetic)
+                 # subgroup ordering.
+                 subgroup_pair = sorted([term1, term2])
+                 try:
+                     joint_npmi_df = npmi_stats.load_or_prepare_joint_npmi(subgroup_pair)
+                     npmi_show(joint_npmi_df)
+                 except KeyError:
+                     st.markdown(
+                         "**WARNING!** The nPMI for these terms has not been pre-computed; please re-run caching."
+                     )
+             else:
+                 st.markdown(
+                     "No words found co-occurring with both of the selected identity terms."
+                 )
+         except Exception:
+             st.write("Under construction! 😱 🚧")
+
+
+ def npmi_show(paired_results):
+     if paired_results.empty:
+         st.markdown(
+             "No words that co-occur enough times for results! Or there's a 🐛. Or we're still computing this one. 🤷"
+         )
+     else:
+         s = pd.DataFrame(paired_results.sort_values(by="npmi-bias", ascending=True))
+         # s.columns = pd.MultiIndex.from_arrays([['npmi', 'npmi', 'npmi', 'count', 'count'], ['bias', 'man', 'straight', 'man', 'straight']])
+         s.index.name = "word"
+         npmi_cols = s.filter(like="npmi").columns
+         count_cols = s.filter(like="count").columns
+         if s.shape[0] > 10000:
+             bias_thres = max(
+                 abs(s["npmi-bias"].iloc[5000]), abs(s["npmi-bias"].iloc[-5000])
+             )
+             print(f"filtering with bias threshold: {bias_thres}")
+             s_filtered = s[s["npmi-bias"].abs() > bias_thres]
+         else:
+             s_filtered = s
+         # TODO: This is a very different look than the duplicates table above. Should probably standardize.
+         cm = sns.diverging_palette(270, 36, s=99, l=48, n=16, as_cmap=True)
+         out_df = (
+             s_filtered.style.background_gradient(subset=npmi_cols, cmap=cm)
+             .format(subset=npmi_cols, formatter="{:,.3f}")
+             .format(subset=count_cols, formatter=int)
+             .set_properties(
+                 subset=count_cols, **{"width": "10em", "text-align": "center"}
+             )
+             .set_properties(**{"align": "center"})
+             .set_caption(
+                 "nPMI scores and co-occurrence counts between the selected identity terms and the words they both co-occur with"
+             )
+         )  # s = pd.read_excel("output.xlsx", index_col="word")
+         st.write("### Here is your dataset's nPMI results:")
+         st.dataframe(out_df)
+
+
+ ### Dumping unused functions here for now
+ ### Second, show the distribution of text perplexities
+ def expander_text_perplexities(text_label_df, sorted_sents_loss, fig_loss):
+     with st.expander("Show text perplexities A", expanded=False):
+         st.markdown("### Text perplexities A")
+         st.plotly_chart(fig_loss, use_container_width=True)
+         start_id_show_loss = st.slider(
+             "Show highest perplexity sentences in A starting at index:",
+             0,
+             text_label_df.shape[0] - 5,
+             value=0,
+             step=5,
+         )
+         for lss, sent in sorted_sents_loss[start_id_show_loss : start_id_show_loss + 5]:
+             st.text(f"{lss:.3f} {sent}")
data_measurements/zipf.py ADDED
@@ -0,0 +1,247 @@
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from pathlib import Path
+
+ import numpy as np
+ import pandas as pd
+ import powerlaw
+ import streamlit as st
+ from scipy.stats import ks_2samp
+ from scipy.stats import zipf as zipf_lib
+
+ from .dataset_utils import CNT, PROP
+
+ # treating inf values as NaN as well
+ pd.set_option("use_inf_as_na", True)
+
+ logs = logging.getLogger(__name__)
+ logs.setLevel(logging.INFO)
+ logs.propagate = False
+
+ if not logs.handlers:
+     Path("./log_files").mkdir(exist_ok=True)
+
+     # Logging info to log file
+     file = logging.FileHandler("./log_files/zipf.log")
+     fileformat = logging.Formatter("%(asctime)s:%(message)s")
+     file.setLevel(logging.INFO)
+     file.setFormatter(fileformat)
+
+     # Logging warnings (and above) to stream
+     stream = logging.StreamHandler()
+     streamformat = logging.Formatter("[data_measurements_tool] %(message)s")
+     stream.setLevel(logging.WARNING)
+     stream.setFormatter(streamformat)
+
+     logs.addHandler(file)
+     logs.addHandler(stream)
+
+
+ class Zipf:
+     def __init__(self, vocab_counts_df=pd.DataFrame()):
+         self.vocab_counts_df = vocab_counts_df
+         self.alpha = None
+         self.xmin = None
+         self.xmax = None
+         self.fit = None
+         self.ranked_words = {}
+         self.uniq_counts = []
+         self.uniq_ranks = []
+         self.uniq_fit_counts = None
+         self.term_df = None
+         self.pvalue = None
+         self.ks_test = None
+         self.distance = None
+         self.predicted_zipf_counts = None
+         if not self.vocab_counts_df.empty:
+             logs.info("Fitting based on input vocab counts.")
+             self.calc_fit(vocab_counts_df)
+             logs.info("Getting predicted counts.")
+             self.predicted_zipf_counts = self.calc_zipf_counts(vocab_counts_df)
+
+     def load(self, zipf_dict):
+         self.set_xmin(zipf_dict["xmin"])
+         self.set_xmax(zipf_dict["xmax"])
+         self.set_alpha(zipf_dict["alpha"])
+         self.set_ks_distance(zipf_dict["ks_distance"])
+         self.set_p(zipf_dict["p-value"])
+         self.set_unique_ranks(zipf_dict["uniq_ranks"])
+         self.set_unique_counts(zipf_dict["uniq_counts"])
+
+     def calc_fit(self, vocab_counts_df):
+         """
+         Uses the powerlaw package to fit the observed frequencies to a
+         Zipfian distribution.
+         We fit using the KS distance, as that seems more appropriate than MLE.
+         :param vocab_counts_df: Dataframe of vocabulary counts (CNT column)
+         :return:
+         """
+         self.vocab_counts_df = vocab_counts_df
+         # TODO: These proportions may have already been calculated.
+         vocab_counts_df[PROP] = vocab_counts_df[CNT] / float(sum(vocab_counts_df[CNT]))
+         rank_column = vocab_counts_df[CNT].rank(
+             method="dense", numeric_only=True, ascending=False
+         )
+         vocab_counts_df["rank"] = rank_column.astype("int64")
+         observed_counts = vocab_counts_df[CNT].values
+         # Note: another method for determining alpha is defined by
+         # (Newman, 2005): alpha = 1 + n * sum(ln( xi / xmin )) ^ -1
+         self.fit = powerlaw.Fit(observed_counts, fit_method="KS", discrete=True)
+         # This should probably be a pmf (not pdf); using discrete=True above.
+         # original_data=False uses only the fitted data (within xmin and xmax).
+         # pdf_bin_edges: The portion of the data within the bin.
+         # observed_pdf: The probability density function (normalized histogram)
+         # of the data.
+         pdf_bin_edges, observed_pdf = self.fit.pdf(original_data=False)
+         # See the 'Distribution' class described here for info:
+         # https://pythonhosted.org/powerlaw/#powerlaw.Fit.pdf
+         theoretical_distro = self.fit.power_law
+         # The probability density function (normalized histogram) of the
+         # theoretical distribution.
+         predicted_pdf = theoretical_distro.pdf()
+         # !!!! CRITICAL VALUE FOR ZIPF !!!!
+         self.alpha = theoretical_distro.alpha
+         # Exclusive xmin: The optimal xmin *beyond which* the scaling regime of
+         # the power law fits best.
+         self.xmin = theoretical_distro.xmin
+         self.xmax = theoretical_distro.xmax
+         self.distance = theoretical_distro.KS()
+         self.ks_test = ks_2samp(observed_pdf, predicted_pdf)
+         self.pvalue = self.ks_test[1]
+         logs.info("KS test:")
+         logs.info(self.ks_test)
+
+     def set_xmax(self, xmax):
+         """
+         xmax is usually None, so we add some handling to set it as the
+         maximum rank in the dataset.
+         :param xmax:
+         :return:
+         """
+         if xmax:
+             self.xmax = int(xmax)
+         elif self.uniq_counts:
+             self.xmax = int(len(self.uniq_counts))
+         elif self.uniq_ranks:
+             self.xmax = int(len(self.uniq_ranks))
+
+     def get_xmax(self):
+         if not self.xmax:
+             self.set_xmax(self.xmax)
+         return self.xmax
+
+     def set_p(self, p):
+         self.p = float(p)
+
+     def get_p(self):
+         return float(self.p)
+
+     def set_xmin(self, xmin):
+         self.xmin = xmin
+
+     def get_xmin(self):
+         if self.xmin:
+             return int(self.xmin)
+         return self.xmin
+
+     def set_alpha(self, alpha):
+         self.alpha = float(alpha)
+
+     def get_alpha(self):
+         return float(self.alpha)
+
+     def set_ks_distance(self, distance):
+         self.distance = float(distance)
+
+     def get_ks_distance(self):
+         return self.distance
+
+     def calc_zipf_counts(self, vocab_counts_df):
+         """
+         The fit is based on an optimal xmin (minimum rank).
+         We use this to make count estimates for the Zipf fit,
+         by multiplying the fitted pmf value by the sum of counts above xmin.
+         :return: Array of count values following the fitted pmf.
+         """
+         # TODO: Limit from above xmin to below xmax, not just above xmin.
+         counts = vocab_counts_df[CNT]
+         self.uniq_counts = list(pd.unique(counts))
+         self.uniq_ranks = list(np.arange(1, len(self.uniq_counts) + 1))
+         logs.info(self.uniq_counts)
+         logs.info(self.xmin)
+         logs.info(self.xmax)
+         # Makes sure they are ints if not None
+         xmin = self.get_xmin()
+         xmax = self.get_xmax()
+         self.uniq_fit_counts = self.uniq_counts[xmin + 1 : xmax]
+         pmf_mass = float(sum(self.uniq_fit_counts))
+         zipf_counts = np.array(
+             [self.estimate_count(rank, pmf_mass) for rank in self.uniq_ranks]
+         )
+         return zipf_counts
+
+     def estimate_count(self, rank, pmf_mass):
+         return int(round(zipf_lib.pmf(rank, self.alpha) * pmf_mass))
+
+     def set_unique_ranks(self, ranks):
+         self.uniq_ranks = ranks
+
+     def get_unique_ranks(self):
+         return self.uniq_ranks
+
+     def get_unique_fit_counts(self):
+         return self.uniq_fit_counts
+
+     def set_unique_counts(self, counts):
+         self.uniq_counts = counts
+
+     def get_unique_counts(self):
+         return self.uniq_counts
+
+     def set_axes(self, unique_counts, unique_ranks):
+         self.uniq_counts = unique_counts
+         self.uniq_ranks = unique_ranks
+
+     # TODO: Incorporate this function (not currently used)
+     def fit_others(self, fit):
+         st.markdown(
+             "_Checking the log likelihood ratio to see if the data is better explained by other well-behaved distributions..._"
+         )
+         # The first value returned from distribution_compare is the log likelihood ratio
+         better_distro = False
+         trunc = fit.distribution_compare("power_law", "truncated_power_law")
+         if trunc[0] < 0:
+             st.markdown("Seems a truncated power law is a better fit.")
+             better_distro = True
+
+         lognormal = fit.distribution_compare("power_law", "lognormal")
+         if lognormal[0] < 0:
+             st.markdown("Seems a lognormal distribution is a better fit.")
+             st.markdown("But don't panic -- that happens sometimes with language.")
+             better_distro = True
+
+         exponential = fit.distribution_compare("power_law", "exponential")
+         if exponential[0] < 0:
+             st.markdown("Seems an exponential distribution is a better fit. Panic.")
+             better_distro = True
+
+         if not better_distro:
+             st.markdown("\nSeems your data is best fit by a power law. Celebrate!!")
log_files/app.log ADDED
@@ -0,0 +1,59 @@
+ 2023-08-23 17:29:50,194:Using Single Dataset Mode
+ 2023-08-23 17:29:50,202:Using cache
+ 2023-08-23 17:34:04,702:Using Single Dataset Mode
+ 2023-08-23 17:43:38,030:Using Single Dataset Mode
+ 2023-08-23 17:43:38,035:Using cache
+ 2023-08-23 17:45:36,703:Using Single Dataset Mode
+ 2023-08-23 17:48:20,572:Using Single Dataset Mode
+ 2023-08-23 17:52:30,321:Using Single Dataset Mode
+ 2023-08-23 17:54:35,084:Using Single Dataset Mode
+ 2023-08-23 17:56:12,155:Using Comparison Mode
+ 2023-08-24 07:51:23,364:Using Single Dataset Mode
+ 2023-08-24 07:57:23,750:Using Single Dataset Mode
+ 2023-08-24 08:01:29,502:Using Single Dataset Mode
+ 2023-08-24 08:03:08,131:Using Single Dataset Mode
+ 2023-08-24 08:04:51,132:Using Single Dataset Mode
+ 2023-08-24 08:04:51,138:Using cache
+ 2023-08-24 08:10:10,454:Using Single Dataset Mode
+ 2023-08-24 08:15:29,052:Using Single Dataset Mode
+ 2023-08-24 08:15:29,060:Using cache
+ 2023-08-24 08:17:31,506:Using Single Dataset Mode
+ 2023-08-24 08:19:49,714:Using Single Dataset Mode
+ 2023-08-24 18:42:47,928:Using Single Dataset Mode
+ 2023-08-24 18:46:27,220:Using Single Dataset Mode
+ 2023-08-24 18:49:34,812:Using Single Dataset Mode
+ 2023-08-24 18:50:59,294:Using Single Dataset Mode
+ 2023-08-24 18:52:13,936:Using Single Dataset Mode
+ 2023-08-24 18:52:13,942:Using cache
+ 2023-08-24 18:53:35,540:Using Single Dataset Mode
+ 2023-08-24 18:54:55,961:Using Single Dataset Mode
+ 2023-08-24 18:56:59,520:Using Single Dataset Mode
+ 2023-08-24 18:58:22,133:Using Single Dataset Mode
+ 2023-08-24 19:00:13,836:Using Single Dataset Mode
+ 2023-08-24 19:01:23,903:Using Single Dataset Mode
+ 2023-08-24 20:23:51,453:Using Single Dataset Mode
+ 2023-08-24 20:24:59,017:Using Single Dataset Mode
+ 2023-08-24 20:26:46,678:Using Single Dataset Mode
+ 2023-08-24 20:27:59,157:Using Single Dataset Mode
+ 2023-08-24 20:29:31,861:Using Single Dataset Mode
+ 2023-08-24 20:30:48,436:Using Single Dataset Mode
+ 2023-08-24 20:33:15,450:Using Single Dataset Mode
+ 2023-08-24 20:34:29,544:Using Single Dataset Mode
+ 2023-08-25 08:41:31,588:Using Single Dataset Mode
+ 2023-08-25 08:42:41,115:Using Single Dataset Mode
+ 2023-08-25 08:44:16,584:Using Single Dataset Mode
+ 2023-09-26 00:37:43,807:Using Single Dataset Mode
+ 2023-09-26 02:26:14,675:Using Single Dataset Mode
+ 2023-09-26 02:59:35,715:Using Single Dataset Mode
+ 2023-09-26 02:59:35,729:Using cache
+ 2023-09-26 03:00:09,840:Using Single Dataset Mode
+ 2023-09-26 03:00:09,843:Using cache
+ 2023-09-26 03:07:14,181:Using Single Dataset Mode
+ 2023-09-26 03:07:14,191:Using cache
+ 2023-09-26 03:15:33,456:Using Single Dataset Mode
+ 2023-09-26 03:15:33,470:Using cache
+ 2023-09-26 03:33:45,719:Using Single Dataset Mode
+ 2023-09-26 03:33:45,755:Using cache
+ 2023-09-26 03:35:05,699:Using Single Dataset Mode
+ 2023-09-26 05:46:30,460:Using Single Dataset Mode
+ 2023-09-26 05:46:30,460:Using cache
log_files/dataset_statistics.log ADDED
@@ -0,0 +1,4 @@
+ 2023-08-23 17:29:50,216:Loaded dataset from disk
+ 2023-08-23 17:43:38,040:Loaded dataset from disk
+ 2023-08-24 18:52:13,955:Loaded dataset from disk
+ 2023-09-26 05:46:30,524:Loaded dataset from disk
log_files/npmi.log ADDED
File without changes
log_files/zipf.log ADDED
File without changes
run.sh ADDED
@@ -0,0 +1,110 @@
+ #!/usr/bin/env bash
+ python3 run_data_measurements.py --dataset="hate_speech18" --config="default" --split="train" --label_field="label" --feature="text"
+ python3 run_data_measurements.py --dataset="hate_speech_offensive" --config="default" --split="train" --label_field="label" --feature="tweet"
+
+
+ python3 run_data_measurements.py --dataset="imdb" --config="plain_text" --split="train" --label_field="label" --feature="text"
+ python3 run_data_measurements.py --dataset="imdb" --config="plain_text" --split="unsupervised" --label_field="label" --feature="text"
+
+
+ python3 run_data_measurements.py --dataset="glue" --config="cola" --split="train" --label_field="label" --feature="sentence"
+ python3 run_data_measurements.py --dataset="glue" --config="cola" --split="validation" --label_field="label" --feature="sentence"
+
+ python3 run_data_measurements.py --dataset="glue" --config="mnli" --split="train" --label_field="label" --feature="hypothesis"
+ python3 run_data_measurements.py --dataset="glue" --config="mnli" --split="train" --label_field="label" --feature="premise"
+
+ python3 run_data_measurements.py --dataset="glue" --config="mnli" --split="validation_matched" --label_field="label" --feature="premise"
+ python3 run_data_measurements.py --dataset="glue" --config="mnli" --split="validation_matched" --label_field="label" --feature="hypothesis"
+ python3 run_data_measurements.py --dataset="glue" --config="mnli" --split="validation_mismatched" --label_field="label" --feature="premise"
+ python3 run_data_measurements.py --dataset="glue" --config="mnli" --split="validation_mismatched" --label_field="label" --feature="hypothesis"
+
+
+ python3 run_data_measurements.py --dataset="glue" --config="mrpc" --split="train" --label_field="label" --feature="sentence1"
+ python3 run_data_measurements.py --dataset="glue" --config="mrpc" --split="train" --label_field="label" --feature="sentence2"
+ python3 run_data_measurements.py --dataset="glue" --config="mrpc" --split="validation" --label_field="label" --feature="sentence1"
+ python3 run_data_measurements.py --dataset="glue" --config="mrpc" --split="validation" --label_field="label" --feature="sentence2"
+
+
+ python3 run_data_measurements.py --dataset="glue" --config="rte" --split="train" --label_field="label" --feature="sentence1"
+ python3 run_data_measurements.py --dataset="glue" --config="rte" --split="train" --label_field="label" --feature="sentence2"
+ python3 run_data_measurements.py --dataset="glue" --config="rte" --split="validation" --label_field="label" --feature="sentence1"
+ python3 run_data_measurements.py --dataset="glue" --config="rte" --split="validation" --label_field="label" --feature="sentence2"
+
+
+ python3 run_data_measurements.py --dataset="glue" --config="stsb" --split="train" --label_field="label" --feature="sentence1"
+ python3 run_data_measurements.py --dataset="glue" --config="stsb" --split="train" --label_field="label" --feature="sentence2"
+ python3 run_data_measurements.py --dataset="glue" --config="stsb" --split="validation" --label_field="label" --feature="sentence1"
+ python3 run_data_measurements.py --dataset="glue" --config="stsb" --split="validation" --label_field="label" --feature="sentence2"
+
+ python3 run_data_measurements.py --dataset="glue" --config="wnli" --split="train" --label_field="label" --feature="sentence1"
+ python3 run_data_measurements.py --dataset="glue" --config="wnli" --split="train" --label_field="label" --feature="sentence2"
+ python3 run_data_measurements.py --dataset="glue" --config="wnli" --split="validation" --label_field="label" --feature="sentence1"
+ python3 run_data_measurements.py --dataset="glue" --config="wnli" --split="validation" --label_field="label" --feature="sentence2"
+
+ python3 run_data_measurements.py --dataset="glue" --config="sst2" --split="train" --label_field="label" --feature="sentence"
+ python3 run_data_measurements.py --dataset="glue" --config="sst2" --split="validation" --label_field="label" --feature="sentence"
+
+
+ python3 run_data_measurements.py --dataset="glue" --config="qnli" --split="train" --label_field="label" --feature="question"
+ python3 run_data_measurements.py --dataset="glue" --config="qnli" --split="train" --label_field="label" --feature="sentence"
+ python3 run_data_measurements.py --dataset="glue" --config="qnli" --split="validation" --label_field="label" --feature="question"
+ python3 run_data_measurements.py --dataset="glue" --config="qnli" --split="validation" --label_field="label" --feature="sentence"
+
+
+ python3 run_data_measurements.py --dataset="glue" --config="qqp" --split="train" --label_field="label" --feature="question1"
+ python3 run_data_measurements.py --dataset="glue" --config="qqp" --split="train" --label_field="label" --feature="question2"
+ python3 run_data_measurements.py --dataset="glue" --config="qqp" --split="validation" --label_field="label" --feature="question1"
+ python3 run_data_measurements.py --dataset="glue" --config="qqp" --split="validation" --label_field="label" --feature="question2"
+
+ python3 run_data_measurements.py --dataset="glue" --config="mnli_matched" --split="validation" --label_field="label" --feature="hypothesis"
+ python3 run_data_measurements.py --dataset="glue" --config="mnli_matched" --split="validation" --label_field="label" --feature="premise"
+ python3 run_data_measurements.py --dataset="glue" --config="mnli_mismatched" --split="validation" --label_field="label" --feature="hypothesis"
+ python3 run_data_measurements.py --dataset="glue" --config="mnli_mismatched" --split="validation" --label_field="label" --feature="premise"
+
+
+ python3 run_data_measurements.py --dataset="wikitext" --config="wikitext-103-v1" --split="train" --feature="text"
+ python3 run_data_measurements.py --dataset="wikitext" --config="wikitext-103-raw-v1" --split="train" --feature="text"
+ python3 run_data_measurements.py --dataset="wikitext" --config="wikitext-2-v1" --split="train" --feature="text"
+ python3 run_data_measurements.py --dataset="wikitext" --config="wikitext-2-raw-v1" --split="train" --feature="text"
+ python3 run_data_measurements.py --dataset="wikitext" --config="wikitext-103-v1" --split="validation" --feature="text"
+ python3 run_data_measurements.py --dataset="wikitext" --config="wikitext-103-raw-v1" --split="validation" --feature="text"
+ python3 run_data_measurements.py --dataset="wikitext" --config="wikitext-2-v1" --split="validation" --feature="text"
+ python3 run_data_measurements.py --dataset="wikitext" --config="wikitext-2-raw-v1" --split="validation" --feature="text"
+
+
+ # Superglue wsc? wic? rte? record? multirc?
+
+ python3 run_data_measurements.py --dataset="super_glue" --config="boolq" --split="train" --label_field="label" --feature="question"
+ python3 run_data_measurements.py --dataset="super_glue" --config="boolq" --split="validation" --label_field="label" --feature="question"
+ python3 run_data_measurements.py --dataset="super_glue" --config="boolq" --split="train" --label_field="label" --feature="passage"
+ python3 run_data_measurements.py --dataset="super_glue" --config="boolq" --split="validation" --label_field="label" --feature="passage"
+
+ python3 run_data_measurements.py --dataset="super_glue" --config="cb" --split="train" --label_field="label" --feature="premise"
+ python3 run_data_measurements.py --dataset="super_glue" --config="cb" --split="validation" --label_field="label" --feature="premise"
+ python3 run_data_measurements.py --dataset="super_glue" --config="cb" --split="train" --label_field="label" --feature="hypothesis"
+ python3 run_data_measurements.py --dataset="super_glue" --config="cb" --split="validation" --label_field="label" --feature="hypothesis"
+
+
+ python3 run_data_measurements.py --dataset="super_glue" --config="copa" --split="train" --label_field="label" --feature="premise"
+ python3 run_data_measurements.py --dataset="super_glue" --config="copa" --split="validation" --label_field="label" --feature="premise"
+ python3 run_data_measurements.py --dataset="super_glue" --config="copa" --split="train" --label_field="label" --feature="choice1"
+ python3 run_data_measurements.py --dataset="super_glue" --config="copa" --split="validation" --label_field="label" --feature="choice1"
+ python3 run_data_measurements.py --dataset="super_glue" --config="copa" --split="train" --label_field="label" --feature="choice2"
+ python3 run_data_measurements.py --dataset="super_glue" --config="copa" --split="validation" --label_field="label" --feature="choice2"
+ python3 run_data_measurements.py --dataset="super_glue" --config="copa" --split="train" --label_field="label" --feature="question"
+ python3 run_data_measurements.py --dataset="super_glue" --config="copa" --split="validation" --label_field="label" --feature="question"
+
+ python3 run_data_measurements.py --dataset="squad" --config="plain_text" --split="train" --feature="context"
+ python3 run_data_measurements.py --dataset="squad" --config="plain_text" --split="train" --feature="question"
+ python3 run_data_measurements.py --dataset="squad" --config="plain_text" --split="train" --feature="title"
+ python3 run_data_measurements.py --dataset="squad" --config="plain_text" --split="validation" --feature="context"
+ python3 run_data_measurements.py --dataset="squad" --config="plain_text" --split="validation" --feature="question"
+ python3 run_data_measurements.py --dataset="squad" --config="plain_text" --split="validation" --feature="title"
+
+
+ python3 run_data_measurements.py --dataset="squad_v2" --config="squad_v2" --split="train" --feature="context"
+ python3 run_data_measurements.py --dataset="squad_v2" --config="squad_v2" --split="train" --feature="question"
+ python3 run_data_measurements.py --dataset="squad_v2" --config="squad_v2" --split="train" --feature="title"
+ python3 run_data_measurements.py --dataset="squad_v2" --config="squad_v2" --split="validation" --feature="context"
+ python3 run_data_measurements.py --dataset="squad_v2" --config="squad_v2" --split="validation" --feature="question"
+ python3 run_data_measurements.py --dataset="squad_v2" --config="squad_v2" --split="validation" --feature="title"
run_data_measurements.py ADDED
@@ -0,0 +1,296 @@
1
+ import argparse
2
+ import json
3
+ import textwrap
4
+ from os import mkdir
5
+ from os.path import join as pjoin, isdir
6
+
7
+ from data_measurements import dataset_statistics
8
+ from data_measurements import dataset_utils
9
+
10
+
11
+
+ def load_or_prepare_widgets(ds_args, show_embeddings=False, use_cache=False):
+     """
+     Loader specifically for the widgets used in the app -- does not compute
+     intermediate files unless they are missing and are needed for a file
+     used in the UI.
+     Does not take specifications from the user; prepares all widgets.
+     Args:
+         ds_args: Dataset configuration settings (config name, split, etc.)
+         show_embeddings: Whether to compute embeddings (slow)
+         use_cache: Whether to reuse files that have already been computed
+
+     Returns:
+         None. Files are saved to disk in cache_dir, if the user has not
+         specified another directory.
+     """
+
+     if not isdir(ds_args["cache_dir"]):
+         print("Creating cache")
+         # We need to preprocess everything.
+         # This should eventually all go into a prepare_dataset CLI.
+         mkdir(ds_args["cache_dir"])
+
+     dstats = dataset_statistics.DatasetStatisticsCacheClass(**ds_args,
+                                                             use_cache=use_cache)
+     # Base dataset, needed by all of the widgets below
+     dstats.load_or_prepare_dataset()
+     # Header widget
+     dstats.load_or_prepare_dset_peek()
+     # General stats widget
+     dstats.load_or_prepare_general_stats()
+     # Labels widget; skipped when no label field is configured
+     try:
+         dstats.set_label_field(ds_args["label_field"])
+         dstats.load_or_prepare_labels()
+     except Exception:
+         # Not a bare `except:`, so KeyboardInterrupt/SystemExit still propagate.
+         pass
+     # Text lengths widget
+     dstats.load_or_prepare_text_lengths()
+     if show_embeddings:
+         # Embeddings widget
+         dstats.load_or_prepare_embeddings()
+     # Text duplicates widget
+     dstats.load_or_prepare_text_duplicates()
+     # nPMI widget
+     dstats.load_or_prepare_npmi()
+     npmi_stats = dstats.npmi_stats
+     # Handling for all pairs; in the UI, people select which pairs to view.
+     do_npmi(npmi_stats)
+     # Zipf widget
+     dstats.load_or_prepare_zipf()
+
+
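Each dstats.load_or_prepare_* call above follows the same compute-once, cache-to-disk pattern. A minimal sketch of that idiom in isolation (load_or_prepare_json and its arguments are illustrative names, not part of this repo):

    # Sketch only: the generic "load or prepare" caching idiom.
    import json
    from os.path import exists

    def load_or_prepare_json(cache_fid, prepare_fn, use_cache=True):
        # Reuse the cached result when allowed and present...
        if use_cache and exists(cache_fid):
            with open(cache_fid) as f:
                return json.load(f)
        # ...otherwise compute it once and cache it for next time.
        result = prepare_fn()
        with open(cache_fid, "w") as f:
            json.dump(result, f)
        return result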
+ def load_or_prepare(dataset_args, use_cache=False):
+     """
+     Users can specify which aspects of the dataset they would like to compute.
+     This additionally computes intermediate files not used in the UI.
+     If the calculation flag (-w) is not specified by the user, calculates
+     everything except embeddings, as those are quite time-consuming and so
+     should be requested separately.
+     Args:
+         dataset_args: Dataset configuration settings (config name, split, etc.)
+         use_cache: Whether to reuse files that have already been computed
+
+     Returns:
+         None. Files are saved to disk in cache_dir, if the user has not
+         specified another directory.
+     """
+     # Named do_all rather than `all` to avoid shadowing the builtin.
+     do_all = False
+     dstats = dataset_statistics.DatasetStatisticsCacheClass(**dataset_args,
+                                                             use_cache=use_cache)
+     print("Loading dataset.")
+     dstats.load_or_prepare_dataset()
+     print("Dataset loaded. Preparing vocab.")
+     dstats.load_or_prepare_vocab()
+     print("Vocab prepared.")
+
+     if not dataset_args["calculation"]:
+         do_all = True
+
+     if do_all or dataset_args["calculation"] == "general":
+         print("\n* Calculating general statistics.")
+         dstats.load_or_prepare_general_stats()
+         print("Done!")
+         print("Basic text statistics now available at %s." %
+               dstats.general_stats_json_fid)
+         print(
+             "Text duplicates now available at %s." % dstats.dup_counts_df_fid
+         )
+
+     if do_all or dataset_args["calculation"] == "lengths":
+         print("\n* Calculating text lengths.")
+         dstats.load_or_prepare_text_lengths()
+         print("Done!")
+
+     if do_all or dataset_args["calculation"] == "labels":
+         if not dstats.label_field:
+             print("Warning: You asked for label calculation, but didn't "
+                   "provide the label field name. Assuming it is 'label'...")
+             dstats.set_label_field("label")
+         # Compute the distribution in either case (previously this branch was
+         # only reached when a label field had been provided explicitly).
+         print("\n* Calculating label distribution.")
+         dstats.load_or_prepare_labels()
+         fig_label_html = pjoin(dstats.cache_path, "labels_fig.html")
+         fig_label_json = pjoin(dstats.cache_path, "labels.json")
+         dstats.fig_labels.write_html(fig_label_html)
+         with open(fig_label_json, "w+") as f:
+             json.dump(dstats.fig_labels.to_json(), f)
+         print("Done!")
+         print("Label distribution now available at %s." %
+               dstats.label_dset_fid)
+         print("Figure saved to %s." % fig_label_html)
+
+     if do_all or dataset_args["calculation"] == "npmi":
+         print("\n* Preparing nPMI.")
+         npmi_stats = dataset_statistics.nPMIStatisticsCacheClass(
+             dstats, use_cache=use_cache
+         )
+         do_npmi(npmi_stats)
+         print("Done!")
+         print(
+             "nPMI results now available in %s for all identity terms that "
+             "occur more than 10 times and all words that "
+             "co-occur with both terms."
+             % npmi_stats.pmi_cache_path
+         )
+
+     if do_all or dataset_args["calculation"] == "zipf":
+         print("\n* Preparing Zipf.")
+         zipf_fig_fid = pjoin(dstats.cache_path, "zipf_fig.html")
+         zipf_json_fid = pjoin(dstats.cache_path, "zipf_fig.json")
+         dstats.load_or_prepare_zipf()
+         zipf_fig = dstats.zipf_fig
+         with open(zipf_json_fid, "w+") as f:
+             json.dump(zipf_fig.to_json(), f)
+         zipf_fig.write_html(zipf_fig_fid)
+         print("Done!")
+         print("Zipf results now available at %s." % dstats.zipf_fid)
+         print(
+             "Figure saved to %s, with corresponding json at %s."
+             % (zipf_fig_fid, zipf_json_fid)
+         )
+
+     # Don't do this one unless someone specifically asks for it -- it takes a while.
+     if dataset_args["calculation"] == "embeddings":
+         print("\n* Preparing text embeddings.")
+         dstats.load_or_prepare_embeddings()
+
+
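To run a single measurement rather than the full battery, the same entry point can be called with "calculation" set. A minimal sketch, with illustrative field values shaped like the dataset_args dict built in get_text_label_df below:

    # Sketch only: compute only the Zipf statistics for one split.
    zipf_only_args = {
        "dset_name": "imdb", "dset_config": "plain_text",
        "split_name": "train", "text_field": "text",
        "label_field": (), "label_names": [],
        "calculation": "zipf", "cache_dir": "cache_dir",
    }
    load_or_prepare(zipf_only_args, use_cache=True)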
+ def do_npmi(npmi_stats):
+     available_terms = npmi_stats.load_or_prepare_npmi_terms()
+     completed_pairs = set()
+     print("Iterating through terms for joint npmi.")
+     for term1 in available_terms:
+         for term2 in available_terms:
+             if term1 != term2:
+                 sorted_terms = tuple(sorted([term1, term2]))
+                 if sorted_terms not in completed_pairs:
+                     print("Computing nPMI statistics for %s and %s" % sorted_terms)
+                     _ = npmi_stats.load_or_prepare_joint_npmi(sorted_terms)
+                     completed_pairs.add(sorted_terms)
+
+
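Because nPMI pairs are unordered, the double loop plus the completed_pairs bookkeeping above is equivalent to iterating over unique pairs directly. A hedged alternative sketch (do_npmi_combinations is an illustrative name, not part of this commit):

    # Sketch only: an equivalent pair iteration using itertools.combinations.
    # Sorting the vocabulary first means each emitted pair is already sorted,
    # so no completed_pairs bookkeeping is needed.
    from itertools import combinations

    def do_npmi_combinations(npmi_stats):
        available_terms = npmi_stats.load_or_prepare_npmi_terms()
        for term1, term2 in combinations(sorted(available_terms), 2):
            print("Computing nPMI statistics for %s and %s" % (term1, term2))
            _ = npmi_stats.load_or_prepare_joint_npmi((term1, term2))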
+ def get_text_label_df(
+     ds_name,
+     config_name,
+     split_name,
+     text_field,
+     label_field,
+     calculation,
+     out_dir,
+     use_cache=True,
+ ):
+     if not use_cache:
+         print("Not using any cache; starting afresh")
+     ds_name_to_dict = dataset_utils.get_dataset_info_dicts(ds_name)
+     if label_field:
+         # Look up the label column and its class names from the dataset info;
+         # fall back to "no labels" when the info has no entry for the field.
+         label_field, label_names = (
+             ds_name_to_dict[ds_name][config_name]["features"][label_field][0]
+             if len(ds_name_to_dict[ds_name][config_name]["features"][label_field]) > 0
+             else ((), [])
+         )
+     else:
+         label_field = ()
+         label_names = []
+     dataset_args = {
+         "dset_name": ds_name,
+         "dset_config": config_name,
+         "split_name": split_name,
+         "text_field": text_field,
+         "label_field": label_field,
+         "label_names": label_names,
+         "calculation": calculation,
+         "cache_dir": out_dir,
+     }
+     load_or_prepare(dataset_args, use_cache=use_cache)
+
+
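A hedged usage sketch: the programmatic equivalent of the IMDB command shown in main()'s help text below, calling get_text_label_df directly:

    # Sketch only: equivalent of
    #   python3 run_data_measurements.py --dataset="imdb" --config="plain_text"
    #       --split="train" --label_field="label" --feature="text"
    get_text_label_df("imdb", "plain_text", "train", "text", "label",
                      calculation=None, out_dir="cache_dir", use_cache=True)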
+ def main():
+     # TODO: Make this the Hugging Face arg parser.
+     parser = argparse.ArgumentParser(
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         description=textwrap.dedent(
+             """
+
+         Example for the hate_speech18 dataset:
+         python3 run_data_measurements.py --dataset="hate_speech18" --config="default" --split="train" --feature="text"
+
+         Example for the IMDB dataset:
+         python3 run_data_measurements.py --dataset="imdb" --config="plain_text" --split="train" --label_field="label" --feature="text"
+         """
+         ),
+     )
+
+     parser.add_argument(
+         "-d", "--dataset", required=True, help="Name of dataset to prepare"
+     )
+     parser.add_argument(
+         "-c", "--config", required=True, help="Dataset configuration to prepare"
+     )
+     parser.add_argument(
+         "-s", "--split", required=True, type=str, help="Dataset split to prepare"
+     )
+     parser.add_argument(
+         "-f",
+         "--feature",
+         required=True,
+         type=str,
+         default="text",
+         help="Text column to prepare",
+     )
+     parser.add_argument(
+         "-w",
+         "--calculation",
+         help="""What to calculate (defaults to everything except embeddings).\n
+     Options are:\n
+
+     - `general`: duplicate counts, missing values, length statistics\n
+
+     - `lengths`: text length distribution\n
+
+     - `labels`: label distribution\n
+
+     - `embeddings`: text embeddings (Warning: slow.)\n
+
+     - `npmi`: word associations\n
+
+     - `zipf`: Zipfian statistics
+     """,
+     )
+     parser.add_argument(
+         "-l",
+         "--label_field",
+         type=str,
+         required=False,
+         default="",
+         help="Field name for the label column in the dataset (required if "
+              "there is a label field that you want information about)",
+     )
+     parser.add_argument(
+         "--cached",
+         default=False,
+         required=False,
+         action="store_true",
+         help="Whether to use cached files (optional)",
+     )
+     parser.add_argument(
+         "--do_html",
+         default=False,
+         required=False,
+         action="store_true",
+         help="Whether to write out corresponding HTML files (optional)",
+     )
+     parser.add_argument("--out_dir", default="cache_dir", help="Where to write out to.")
+
+     args = parser.parse_args()
+     print("Proceeding with the following arguments:")
+     print(args)
+     # e.g. run_data_measurements.py -d hate_speech18 -c default -s train -f text -w npmi
+     get_text_label_df(args.dataset, args.config, args.split, args.feature,
+                       args.label_field, args.calculation, args.out_dir,
+                       use_cache=args.cached)
+     print()
+
+
+ if __name__ == "__main__":
+     main()
temp.jsonl ADDED
The diff for this file is too large to render. See raw diff