yourusername commited on
Commit
9b51db9
1 Parent(s): b9430ed

:tada: init

Browse files
data_measurements/__init__.py ADDED
File without changes
data_measurements/dataset_statistics.py ADDED
@@ -0,0 +1,980 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import logging
17
+ import statistics
18
+ from os import mkdir
19
+ from os.path import exists, isdir
20
+ from os.path import join as pjoin
21
+
22
+ import nltk
23
+ import numpy as np
24
+ import pandas as pd
25
+ import plotly.express as px
26
+ import plotly.figure_factory as ff
27
+ import plotly.graph_objects as go
28
+ import pyarrow.feather as feather
29
+ from datasets import load_from_disk
30
+ from nltk.corpus import stopwords
31
+ from sklearn.feature_extraction.text import CountVectorizer
32
+
33
+ from .dataset_utils import (
34
+ CNT,
35
+ DEDUP_TOT,
36
+ EMBEDDING_FIELD,
37
+ LENGTH_FIELD,
38
+ OUR_LABEL_FIELD,
39
+ OUR_TEXT_FIELD,
40
+ PROP,
41
+ TEXT_NAN_CNT,
42
+ TOKENIZED_FIELD,
43
+ TXT_LEN,
44
+ VOCAB,
45
+ WORD,
46
+ extract_field,
47
+ load_truncated_dataset,
48
+ )
49
+ from .embeddings import Embeddings
50
+ from .npmi import nPMI
51
+ from .zipf import Zipf
52
+
53
+ pd.options.display.float_format = "{:,.3f}".format
54
+
55
+ logs = logging.getLogger(__name__)
56
+ logs.setLevel(logging.WARNING)
57
+ logs.propagate = False
58
+
59
+ if not logs.handlers:
60
+
61
+ # Logging info to log file
62
+ file = logging.FileHandler("./log_files/dataset_statistics.log")
63
+ fileformat = logging.Formatter("%(asctime)s:%(message)s")
64
+ file.setLevel(logging.INFO)
65
+ file.setFormatter(fileformat)
66
+
67
+ # Logging debug messages to stream
68
+ stream = logging.StreamHandler()
69
+ streamformat = logging.Formatter("[data_measurements_tool] %(message)s")
70
+ stream.setLevel(logging.WARNING)
71
+ stream.setFormatter(streamformat)
72
+
73
+ logs.addHandler(file)
74
+ logs.addHandler(stream)
75
+
76
+
77
+ # TODO: Read this in depending on chosen language / expand beyond english
78
+ nltk.download("stopwords")
79
+ _CLOSED_CLASS = (
80
+ stopwords.words("english")
81
+ + [
82
+ "t",
83
+ "n",
84
+ "ll",
85
+ "d",
86
+ "wasn",
87
+ "weren",
88
+ "won",
89
+ "aren",
90
+ "wouldn",
91
+ "shouldn",
92
+ "didn",
93
+ "don",
94
+ "hasn",
95
+ "ain",
96
+ "couldn",
97
+ "doesn",
98
+ "hadn",
99
+ "haven",
100
+ "isn",
101
+ "mightn",
102
+ "mustn",
103
+ "needn",
104
+ "shan",
105
+ "would",
106
+ "could",
107
+ "dont",
108
+ "u",
109
+ ]
110
+ + [str(i) for i in range(0, 21)]
111
+ )
112
+ _IDENTITY_TERMS = [
113
+ "man",
114
+ "woman",
115
+ "non-binary",
116
+ "gay",
117
+ "lesbian",
118
+ "queer",
119
+ "trans",
120
+ "straight",
121
+ "cis",
122
+ "she",
123
+ "her",
124
+ "hers",
125
+ "he",
126
+ "him",
127
+ "his",
128
+ "they",
129
+ "them",
130
+ "their",
131
+ "theirs",
132
+ "himself",
133
+ "herself",
134
+ ]
135
+ # treating inf values as NaN as well
136
+ pd.set_option("use_inf_as_na", True)
137
+
138
+ _MIN_VOCAB_COUNT = 10
139
+ _TREE_DEPTH = 12
140
+ _TREE_MIN_NODES = 250
141
+ # as long as we're using sklearn - already pushing the resources
142
+ _MAX_CLUSTER_EXAMPLES = 5000
143
+ _NUM_VOCAB_BATCHES = 2000
144
+
145
+
146
+ _CVEC = CountVectorizer(token_pattern="(?u)\\b\\w+\\b", lowercase=True)
147
+
148
+ num_rows = 200000
149
+
150
+
151
+ class DatasetStatisticsCacheClass:
152
+ def __init__(
153
+ self,
154
+ cache_dir,
155
+ dset_name,
156
+ dset_config,
157
+ split_name,
158
+ text_field,
159
+ label_field,
160
+ label_names,
161
+ calculation=None,
162
+ ):
163
+ # This is only used for standalone runs for each kind of measurement.
164
+ self.calculation = calculation
165
+ self.our_text_field = OUR_TEXT_FIELD
166
+ self.our_length_field = LENGTH_FIELD
167
+ self.our_label_field = OUR_LABEL_FIELD
168
+ self.our_tokenized_field = TOKENIZED_FIELD
169
+ self.our_embedding_field = EMBEDDING_FIELD
170
+ self.cache_dir = cache_dir
171
+ ### What are we analyzing?
172
+ # name of the Hugging Face dataset
173
+ self.dset_name = dset_name
174
+ # name of the dataset config
175
+ self.dset_config = dset_config
176
+ # name of the split to analyze
177
+ self.split_name = split_name
178
+ # which text fields are we analysing?
179
+ self.text_field = text_field
180
+ # which label fields are we analysing?
181
+ self.label_field = label_field
182
+ # what are the names of the classes?
183
+ self.label_names = label_names
184
+ ## Hugging Face dataset objects
185
+ self.dset = None # original dataset
186
+ # HF dataset with all of the self.text_field instances in self.dset
187
+ self.text_dset = None
188
+ # HF dataset with text embeddings in the same order as self.text_dset
189
+ self.embeddings_dset = None
190
+ # HF dataset with all of the self.label_field instances in self.dset
191
+ self.label_dset = None
192
+ ## Data frames
193
+ # Tokenized text
194
+ self.tokenized_df = []
195
+ # save sentence length histogram in the class so it doesn't ge re-computed
196
+ self.fig_tok_length = None
197
+ # Data Frame version of self.label_dset
198
+ self.label_df = None
199
+ # save label pie chart in the class so it doesn't ge re-computed
200
+ self.fig_labels = None
201
+ # Vocabulary with word counts in the dataset
202
+ self.vocab_counts_df = None
203
+ # Vocabulary filtered to remove stopwords
204
+ self.vocab_counts_filtered_df = None
205
+ ## General statistics and duplicates
206
+ # Number of NaN values (NOT empty strings)
207
+ self.text_nan_count = 0
208
+ # Number of text items that appear more than once in the dataset
209
+ self.dedup_total = 0
210
+ # Duplicated text items along with their number of occurences ("count")
211
+ self.text_dup_counts_df = None
212
+ self.avg_length = None
213
+ self.std_length = None
214
+ self.general_stats_dict = None
215
+ # clustering text by embeddings
216
+ # the hierarchical clustering tree is represented as a list of nodes,
217
+ # the first is the root
218
+ self.node_list = []
219
+ # save tree figure in the class so it doesn't ge re-computed
220
+ self.fig_tree = None
221
+ # keep Embeddings object around to explore clusters
222
+ self.embeddings = None
223
+ # nPMI
224
+ # Holds a nPMIStatisticsCacheClass object
225
+ self.npmi_stats = None
226
+ # TODO: Users ideally can type in whatever words they want.
227
+ self.termlist = _IDENTITY_TERMS
228
+ # termlist terms that are available more than _MIN_VOCAB_COUNT times
229
+ self.available_terms = _IDENTITY_TERMS
230
+ # TODO: Have lowercase be an option for a user to set.
231
+ self.to_lowercase = True
232
+ # The minimum amount of times a word should occur to be included in
233
+ # word-count-based calculations (currently just relevant to nPMI)
234
+ self.min_vocab_count = _MIN_VOCAB_COUNT
235
+ # zipf
236
+ self.z = None
237
+ self.zipf_fig = None
238
+ self.cvec = _CVEC
239
+ # File definitions
240
+ # path to the directory used for caching
241
+ if not isinstance(text_field, str):
242
+ text_field = "-".join(text_field)
243
+ if isinstance(label_field, str):
244
+ label_field = label_field
245
+ else:
246
+ label_field = "-".join(label_field)
247
+ self.cache_path = pjoin(
248
+ self.cache_dir,
249
+ f"{dset_name}_{dset_config}_{split_name}_{text_field}_{label_field}",
250
+ )
251
+ if not isdir(self.cache_path):
252
+ logs.warning("Creating cache directory %s." % self.cache_path)
253
+ mkdir(self.cache_path)
254
+ self.dset_fid = pjoin(self.cache_path, "base_dset")
255
+ self.text_dset_fid = pjoin(self.cache_path, "text_dset")
256
+ self.tokenized_df_fid = pjoin(self.cache_path, "tokenized_df.feather")
257
+ self.label_dset_fid = pjoin(self.cache_path, "label_dset")
258
+ self.vocab_counts_df_fid = pjoin(self.cache_path, "vocab_counts.feather")
259
+ self.general_stats_fid = pjoin(self.cache_path, "general_stats.json")
260
+ self.text_duplicate_counts_df_fid = pjoin(
261
+ self.cache_path, "text_dup_counts_df.feather"
262
+ )
263
+ self.zipf_fid = pjoin(self.cache_path, "zipf_basic_stats.json")
264
+
265
+ def get_base_dataset(self):
266
+ """Gets a pointer to the truncated base dataset object."""
267
+ if not self.dset:
268
+ self.dset = load_truncated_dataset(
269
+ self.dset_name,
270
+ self.dset_config,
271
+ self.split_name,
272
+ cache_name=self.dset_fid,
273
+ use_cache=True,
274
+ use_streaming=True,
275
+ )
276
+
277
+ def get_dataset_peek(self):
278
+ self.get_base_dataset()
279
+ return self.dset[:100]
280
+
281
+ def load_or_prepare_general_stats(self, use_cache=False):
282
+ """Data structures used in calculating general statistics and duplicates"""
283
+
284
+ # TODO: These probably don't need to be feather files, could be csv.
285
+ # General statistics
286
+ if (
287
+ use_cache
288
+ and exists(self.general_stats_fid)
289
+ and exists(self.text_duplicate_counts_df_fid)
290
+ ):
291
+ self.load_general_stats(
292
+ self.general_stats_fid, self.text_duplicate_counts_df_fid
293
+ )
294
+ else:
295
+ (
296
+ self.text_nan_count,
297
+ self.dedup_total,
298
+ self.text_dup_counts_df,
299
+ ) = self.prepare_general_text_stats()
300
+ self.general_stats_dict = {
301
+ TEXT_NAN_CNT: self.text_nan_count,
302
+ DEDUP_TOT: self.dedup_total,
303
+ }
304
+ write_df(self.text_dup_counts_df, self.text_duplicate_counts_df_fid)
305
+ write_json(self.general_stats_dict, self.general_stats_fid)
306
+
307
+ def load_or_prepare_text_lengths(self, use_cache=False):
308
+ if len(self.tokenized_df) == 0:
309
+ self.tokenized_df = self.do_tokenization()
310
+ self.tokenized_df[LENGTH_FIELD] = self.tokenized_df[TOKENIZED_FIELD].apply(len)
311
+ self.avg_length = round(
312
+ sum(self.tokenized_df[self.our_length_field])
313
+ / len(self.tokenized_df[self.our_length_field]),
314
+ 1,
315
+ )
316
+ self.std_length = round(
317
+ statistics.stdev(self.tokenized_df[self.our_length_field]), 1
318
+ )
319
+ self.fig_tok_length = make_fig_lengths(self.tokenized_df, self.our_length_field)
320
+
321
+ def load_or_prepare_embeddings(self, use_cache=False):
322
+ self.embeddings = Embeddings(self, use_cache=use_cache)
323
+ self.embeddings.make_hierarchical_clustering()
324
+ self.fig_tree = self.embeddings.fig_tree
325
+ self.node_list = self.embeddings.node_list
326
+
327
+ # get vocab with word counts
328
+ def load_or_prepare_vocab(self, use_cache=True, save=True):
329
+ """
330
+ Calculates the vocabulary count from the tokenized text.
331
+ The resulting dataframes may be used in nPMI calculations, zipf, etc.
332
+ :param use_cache:
333
+ :return:
334
+ """
335
+ if (
336
+ use_cache
337
+ and exists(self.vocab_counts_df_fid)
338
+ ):
339
+ logs.info("Reading vocab from cache")
340
+ self.load_vocab()
341
+ self.vocab_counts_filtered_df = filter_vocab(self.vocab_counts_df)
342
+ else:
343
+ logs.info("Calculating vocab afresh")
344
+ if len(self.tokenized_df) == 0:
345
+ self.tokenized_df = self.do_tokenization()
346
+ if save:
347
+ logs.info("Writing out.")
348
+ write_df(self.tokenized_df, self.tokenized_df_fid)
349
+ word_count_df = count_vocab_frequencies(self.tokenized_df)
350
+ logs.info("Making dfs with proportion.")
351
+ self.vocab_counts_df = calc_p_word(word_count_df)
352
+ self.vocab_counts_filtered_df = filter_vocab(self.vocab_counts_df)
353
+ if save:
354
+ logs.info("Writing out.")
355
+ write_df(self.vocab_counts_df, self.vocab_counts_df_fid)
356
+ logs.info("unfiltered vocab")
357
+ logs.info(self.vocab_counts_df)
358
+ logs.info("filtered vocab")
359
+ logs.info(self.vocab_counts_filtered_df)
360
+
361
+ def load_or_prepare_npmi_terms(self, use_cache=False):
362
+ self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=use_cache)
363
+ self.npmi_stats.load_or_prepare_npmi_terms()
364
+
365
+ def load_or_prepare_zipf(self, use_cache=False):
366
+ if use_cache and exists(self.zipf_fid):
367
+ # TODO: Read zipf data so that the vocab is there.
368
+ with open(self.zipf_fid, "r") as f:
369
+ zipf_dict = json.load(f)
370
+ self.z = Zipf()
371
+ self.z.load(zipf_dict)
372
+ else:
373
+ self.z = Zipf(self.vocab_counts_df)
374
+ write_zipf_data(self.z, self.zipf_fid)
375
+ self.zipf_fig = make_zipf_fig(self.vocab_counts_df, self.z)
376
+
377
+ def prepare_general_text_stats(self):
378
+ text_nan_count = int(self.tokenized_df.isnull().sum().sum())
379
+ dup_df = self.tokenized_df[self.tokenized_df.duplicated([self.our_text_field])]
380
+ dedup_df = pd.DataFrame(
381
+ dup_df.pivot_table(
382
+ columns=[self.our_text_field], aggfunc="size"
383
+ ).sort_values(ascending=False),
384
+ columns=[CNT],
385
+ )
386
+ dedup_df.index = dedup_df.index.map(str)
387
+ dedup_df[OUR_TEXT_FIELD] = dedup_df.index
388
+ dedup_total = sum(dedup_df[CNT])
389
+ return text_nan_count, dedup_total, dedup_df
390
+
391
+ def load_general_stats(self, general_stats_fid, text_duplicate_counts_df_fid):
392
+ general_stats = json.load(open(general_stats_fid, encoding="utf-8"))
393
+ self.text_nan_count = general_stats[TEXT_NAN_CNT]
394
+ self.dedup_total = general_stats[DEDUP_TOT]
395
+ with open(text_duplicate_counts_df_fid, "rb") as f:
396
+ self.text_dup_counts_df = feather.read_feather(f)
397
+
398
+ def load_or_prepare_dataset(self, use_cache=True, use_df=False, save=True):
399
+ """
400
+ Prepares the HF datasets and data frames containing the untokenized and tokenized
401
+ text as well as the label values. If cache is not being used (use_cache=False), writes the datasets to text.
402
+ :param use_cache:
403
+ :param use_df: Whether to used stored dataframes rather than dset files
404
+ :return:
405
+ """
406
+ ## Raw text first, then tokenization.
407
+ # Use what has been previously stored in DataFrame form or Dataset form.
408
+ if (
409
+ use_cache
410
+ and use_df
411
+ and exists(self.tokenized_df_fid)
412
+ ):
413
+ self.tokenized_df = feather.read_feather(self.tokenized_df_fid)
414
+ elif (
415
+ use_cache and exists(self.text_dset_fid)):
416
+ # load extracted text
417
+ self.text_dset = load_from_disk(self.text_dset_fid)
418
+ logs.warning("Loaded dataset from disk")
419
+ logs.info(self.text_dset)
420
+ # ...Or load it from the server and store it anew
421
+ else:
422
+ self.get_base_dataset()
423
+ # extract all text instances
424
+ self.text_dset = self.dset.map(
425
+ lambda examples: extract_field(
426
+ examples, self.text_field, OUR_TEXT_FIELD
427
+ ),
428
+ batched=True,
429
+ remove_columns=list(self.dset.features),
430
+ )
431
+ if save:
432
+ # save extracted text instances
433
+ logs.warning("Saving dataset to disk")
434
+ self.text_dset.save_to_disk(self.text_dset_fid)
435
+ # tokenize all text instances
436
+ self.tokenized_df = self.do_tokenization()
437
+ if save:
438
+ # save tokenized text
439
+ write_df(self.tokenized_df, self.tokenized_df_fid)
440
+
441
+ def do_tokenization(self):
442
+ """
443
+ Tokenizes the dataset
444
+ :return:
445
+ """
446
+ sent_tokenizer = self.cvec.build_tokenizer()
447
+
448
+ def tokenize_batch(examples):
449
+ # TODO: lowercase should be an option
450
+ res = {
451
+ TOKENIZED_FIELD: [
452
+ tuple(sent_tokenizer(text.lower()))
453
+ for text in examples[OUR_TEXT_FIELD]
454
+ ]
455
+ }
456
+ res[LENGTH_FIELD] = [len(tok_text) for tok_text in res[TOKENIZED_FIELD]]
457
+ return res
458
+
459
+ tokenized_dset = self.text_dset.map(
460
+ tokenize_batch,
461
+ batched=True,
462
+ # remove_columns=[OUR_TEXT_FIELD], keep around to print
463
+ )
464
+ tokenized_df = pd.DataFrame(tokenized_dset)
465
+ return tokenized_df
466
+
467
+ def set_label_field(self, label_field="label"):
468
+ """
469
+ Setter for label_field. Used in the CLI when a user asks for information
470
+ about labels, but does not specify the field;
471
+ 'label' is assumed as a default.
472
+ """
473
+ self.label_field = label_field
474
+
475
+ def load_or_prepare_labels(self, use_cache=False, save=True):
476
+ """
477
+ Extracts labels from the Dataset
478
+ :param use_cache:
479
+ :return:
480
+ """
481
+ # extracted labels
482
+ if len(self.label_field) > 0:
483
+ if use_cache and exists(self.label_dset_fid):
484
+ # load extracted labels
485
+ self.label_dset = load_from_disk(self.label_dset_fid)
486
+ else:
487
+ self.get_base_dataset()
488
+ self.label_dset = self.dset.map(
489
+ lambda examples: extract_field(
490
+ examples, self.label_field, OUR_LABEL_FIELD
491
+ ),
492
+ batched=True,
493
+ remove_columns=list(self.dset.features),
494
+ )
495
+ if save:
496
+ # save extracted label instances
497
+ self.label_dset.save_to_disk(self.label_dset_fid)
498
+ self.label_df = self.label_dset.to_pandas()
499
+
500
+ self.fig_labels = make_fig_labels(
501
+ self.label_df, self.label_names, OUR_LABEL_FIELD
502
+ )
503
+
504
+ def load_vocab(self):
505
+ with open(self.vocab_counts_df_fid, "rb") as f:
506
+ self.vocab_counts_df = feather.read_feather(f)
507
+ # Handling for changes in how the index is saved.
508
+ self.vocab_counts_df = self._set_idx_col_names(self.vocab_counts_df)
509
+
510
+ def _set_idx_col_names(self, input_vocab_df):
511
+ if input_vocab_df.index.name != VOCAB and VOCAB in input_vocab_df.columns:
512
+ input_vocab_df = input_vocab_df.set_index([VOCAB])
513
+ input_vocab_df[VOCAB] = input_vocab_df.index
514
+ return input_vocab_df
515
+
516
+
517
+ class nPMIStatisticsCacheClass:
518
+ """ "Class to interface between the app and the nPMI class
519
+ by calling the nPMI class with the user's selections."""
520
+
521
+ def __init__(self, dataset_stats, use_cache=False):
522
+ self.dstats = dataset_stats
523
+ self.pmi_cache_path = pjoin(self.dstats.cache_path, "pmi_files")
524
+ if not isdir(self.pmi_cache_path):
525
+ logs.warning("Creating pmi cache directory %s." % self.pmi_cache_path)
526
+ # We need to preprocess everything.
527
+ mkdir(self.pmi_cache_path)
528
+ self.joint_npmi_df_dict = {}
529
+ self.termlist = self.dstats.termlist
530
+ logs.info(self.termlist)
531
+ self.use_cache = use_cache
532
+ # TODO: Let users specify
533
+ self.open_class_only = True
534
+ self.min_vocab_count = self.dstats.min_vocab_count
535
+ self.subgroup_files = {}
536
+ self.npmi_terms_fid = pjoin(self.dstats.cache_path, "npmi_terms.json")
537
+ self.available_terms = self.dstats.available_terms
538
+ logs.info(self.available_terms)
539
+
540
+ def load_or_prepare_npmi_terms(self, use_cache=False):
541
+ """
542
+ Figures out what identity terms the user can select, based on whether
543
+ they occur more than self.min_vocab_count times
544
+ :param use_cache:
545
+ :return: Identity terms occurring at least self.min_vocab_count times.
546
+ """
547
+ # TODO: Add the user's ability to select subgroups.
548
+ # TODO: Make min_vocab_count here value selectable by the user.
549
+ if (
550
+ use_cache
551
+ and exists(self.npmi_terms_fid)
552
+ and json.load(open(self.npmi_terms_fid))["available terms"] != []
553
+ ):
554
+ available_terms = json.load(open(self.npmi_terms_fid))["available terms"]
555
+ else:
556
+ true_false = [
557
+ term in self.dstats.vocab_counts_df.index for term in self.termlist
558
+ ]
559
+ word_list_tmp = [x for x, y in zip(self.termlist, true_false) if y]
560
+ true_false_counts = [
561
+ self.dstats.vocab_counts_df.loc[word, CNT] >= self.min_vocab_count
562
+ for word in word_list_tmp
563
+ ]
564
+ available_terms = [
565
+ word for word, y in zip(word_list_tmp, true_false_counts) if y
566
+ ]
567
+ logs.info(available_terms)
568
+ with open(self.npmi_terms_fid, "w+") as f:
569
+ json.dump({"available terms": available_terms}, f)
570
+ self.available_terms = available_terms
571
+ return available_terms
572
+
573
+ def load_or_prepare_joint_npmi(self, subgroup_pair, use_cache=True):
574
+ """
575
+ Run on-the fly, while the app is already open,
576
+ as it depends on the subgroup terms that the user chooses
577
+ :param subgroup_pair:
578
+ :return:
579
+ """
580
+ # Canonical ordering for subgroup_list
581
+ subgroup_pair = sorted(subgroup_pair)
582
+ subgroups_str = "-".join(subgroup_pair)
583
+ if not isdir(self.pmi_cache_path):
584
+ logs.warning("Creating cache")
585
+ # We need to preprocess everything.
586
+ # This should eventually all go into a prepare_dataset CLI
587
+ mkdir(self.pmi_cache_path)
588
+ joint_npmi_fid = pjoin(self.pmi_cache_path, subgroups_str + "_npmi.csv")
589
+ subgroup_files = define_subgroup_files(subgroup_pair, self.pmi_cache_path)
590
+ # Defines the filenames for the cache files from the selected subgroups.
591
+ # Get as much precomputed data as we can.
592
+ if use_cache and exists(joint_npmi_fid):
593
+ # When everything is already computed for the selected subgroups.
594
+ logs.info("Loading cached joint npmi")
595
+ joint_npmi_df = self.load_joint_npmi_df(joint_npmi_fid)
596
+ # When maybe some things have been computed for the selected subgroups.
597
+ else:
598
+ logs.info("Preparing new joint npmi")
599
+ joint_npmi_df, subgroup_dict = self.prepare_joint_npmi_df(
600
+ subgroup_pair, subgroup_files
601
+ )
602
+ # Cache new results
603
+ logs.info("Writing out.")
604
+ for subgroup in subgroup_pair:
605
+ write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files)
606
+ with open(joint_npmi_fid, "w+") as f:
607
+ joint_npmi_df.to_csv(f)
608
+ logs.info("The joint npmi df is")
609
+ logs.info(joint_npmi_df)
610
+ return joint_npmi_df
611
+
612
+ def load_joint_npmi_df(self, joint_npmi_fid):
613
+ """
614
+ Reads in a saved dataframe with all of the paired results.
615
+ :param joint_npmi_fid:
616
+ :return: paired results
617
+ """
618
+ with open(joint_npmi_fid, "rb") as f:
619
+ joint_npmi_df = pd.read_csv(f)
620
+ joint_npmi_df = self._set_idx_cols_from_cache(joint_npmi_df)
621
+ return joint_npmi_df.dropna()
622
+
623
+ def prepare_joint_npmi_df(self, subgroup_pair, subgroup_files):
624
+ """
625
+ Computs the npmi bias based on the given subgroups.
626
+ Handles cases where some of the selected subgroups have cached nPMI
627
+ computations, but other's don't, computing everything afresh if there
628
+ are not cached files.
629
+ :param subgroup_pair:
630
+ :return: Dataframe with nPMI for the words, nPMI bias between the words.
631
+ """
632
+ subgroup_dict = {}
633
+ # When npmi is computed for some (but not all) of subgroup_list
634
+ for subgroup in subgroup_pair:
635
+ logs.info("Load or failing...")
636
+ # When subgroup npmi has been computed in a prior session.
637
+ cached_results = self.load_or_fail_cached_npmi_scores(
638
+ subgroup, subgroup_files[subgroup]
639
+ )
640
+ # If the function did not return False and we did find it, use.
641
+ if cached_results:
642
+ # FYI: subgroup_cooc_df, subgroup_pmi_df, subgroup_npmi_df = cached_results
643
+ # Holds the previous sessions' data for use in this session.
644
+ subgroup_dict[subgroup] = cached_results
645
+ logs.info("Calculating for subgroup list")
646
+ joint_npmi_df, subgroup_dict = self.do_npmi(subgroup_pair, subgroup_dict)
647
+ return joint_npmi_df.dropna(), subgroup_dict
648
+
649
+ # TODO: Update pairwise assumption
650
+ def do_npmi(self, subgroup_pair, subgroup_dict):
651
+ """
652
+ Calculates nPMI for given identity terms and the nPMI bias between.
653
+ :param subgroup_pair: List of identity terms to calculate the bias for
654
+ :return: Subset of data for the UI
655
+ :return: Selected identity term's co-occurrence counts with
656
+ other words, pmi per word, and nPMI per word.
657
+ """
658
+ logs.info("Initializing npmi class")
659
+ npmi_obj = self.set_npmi_obj()
660
+ # Canonical ordering used
661
+ subgroup_pair = tuple(sorted(subgroup_pair))
662
+ # Calculating nPMI statistics
663
+ for subgroup in subgroup_pair:
664
+ # If the subgroup data is already computed, grab it.
665
+ # TODO: Should we set idx and column names similarly to how we set them for cached files?
666
+ if subgroup not in subgroup_dict:
667
+ logs.info("Calculating statistics for %s" % subgroup)
668
+ vocab_cooc_df, pmi_df, npmi_df = npmi_obj.calc_metrics(subgroup)
669
+ # Store the nPMI information for the current subgroups
670
+ subgroup_dict[subgroup] = (vocab_cooc_df, pmi_df, npmi_df)
671
+ # Pair the subgroups together, indexed by all words that
672
+ # co-occur between them.
673
+ logs.info("Computing pairwise npmi bias")
674
+ paired_results = npmi_obj.calc_paired_metrics(subgroup_pair, subgroup_dict)
675
+ UI_results = make_npmi_fig(paired_results, subgroup_pair)
676
+ return UI_results, subgroup_dict
677
+
678
+ def set_npmi_obj(self):
679
+ """
680
+ Initializes the nPMI class with the given words and tokenized sentences.
681
+ :return:
682
+ """
683
+ npmi_obj = nPMI(self.dstats.vocab_counts_df, self.dstats.tokenized_df)
684
+ return npmi_obj
685
+
686
+ def load_or_fail_cached_npmi_scores(self, subgroup, subgroup_fids):
687
+ """
688
+ Reads cached scores from the specified subgroup files
689
+ :param subgroup: string of the selected identity term
690
+ :return:
691
+ """
692
+ # TODO: Ordering of npmi, pmi, vocab triple should be consistent
693
+ subgroup_npmi_fid, subgroup_pmi_fid, subgroup_cooc_fid = subgroup_fids
694
+ if (
695
+ exists(subgroup_npmi_fid)
696
+ and exists(subgroup_pmi_fid)
697
+ and exists(subgroup_cooc_fid)
698
+ ):
699
+ logs.info("Reading in pmi data....")
700
+ with open(subgroup_cooc_fid, "rb") as f:
701
+ subgroup_cooc_df = pd.read_csv(f)
702
+ logs.info("pmi")
703
+ with open(subgroup_pmi_fid, "rb") as f:
704
+ subgroup_pmi_df = pd.read_csv(f)
705
+ logs.info("npmi")
706
+ with open(subgroup_npmi_fid, "rb") as f:
707
+ subgroup_npmi_df = pd.read_csv(f)
708
+ subgroup_cooc_df = self._set_idx_cols_from_cache(
709
+ subgroup_cooc_df, subgroup, "count"
710
+ )
711
+ subgroup_pmi_df = self._set_idx_cols_from_cache(
712
+ subgroup_pmi_df, subgroup, "pmi"
713
+ )
714
+ subgroup_npmi_df = self._set_idx_cols_from_cache(
715
+ subgroup_npmi_df, subgroup, "npmi"
716
+ )
717
+ return subgroup_cooc_df, subgroup_pmi_df, subgroup_npmi_df
718
+ return False
719
+
720
+ def _set_idx_cols_from_cache(self, csv_df, subgroup=None, calc_str=None):
721
+ """
722
+ Helps make sure all of the read-in files can be accessed within code
723
+ via standardized indices and column names.
724
+ :param csv_df:
725
+ :param subgroup:
726
+ :param calc_str:
727
+ :return:
728
+ """
729
+ # The csv saves with this column instead of the index, so that's weird.
730
+ if "Unnamed: 0" in csv_df.columns:
731
+ csv_df = csv_df.set_index("Unnamed: 0")
732
+ csv_df.index.name = WORD
733
+ elif WORD in csv_df.columns:
734
+ csv_df = csv_df.set_index(WORD)
735
+ csv_df.index.name = WORD
736
+ elif VOCAB in csv_df.columns:
737
+ csv_df = csv_df.set_index(VOCAB)
738
+ csv_df.index.name = WORD
739
+ if subgroup and calc_str:
740
+ csv_df.columns = [subgroup + "-" + calc_str]
741
+ elif subgroup:
742
+ csv_df.columns = [subgroup]
743
+ elif calc_str:
744
+ csv_df.columns = [calc_str]
745
+ return csv_df
746
+
747
+ def get_available_terms(self, use_cache=False):
748
+ return self.load_or_prepare_npmi_terms(use_cache=use_cache)
749
+
750
+ def dummy(doc):
751
+ return doc
752
+
753
+ def count_vocab_frequencies(tokenized_df):
754
+ """
755
+ Based on an input pandas DataFrame with a 'text' column,
756
+ this function will count the occurrences of all words.
757
+ :return: [num_words x num_sentences] DataFrame with the rows corresponding to the
758
+ different vocabulary words and the column to the presence (0 or 1) of that word.
759
+ """
760
+
761
+ cvec = CountVectorizer(
762
+ tokenizer=dummy,
763
+ preprocessor=dummy,
764
+ )
765
+ # We do this to calculate per-word statistics
766
+ # Fast calculation of single word counts
767
+ logs.info("Fitting dummy tokenization to make matrix using the previous tokenization")
768
+ cvec.fit(tokenized_df[TOKENIZED_FIELD])
769
+ document_matrix = cvec.transform(tokenized_df[TOKENIZED_FIELD])
770
+ batches = np.linspace(0, tokenized_df.shape[0], _NUM_VOCAB_BATCHES).astype(int)
771
+ i = 0
772
+ tf = []
773
+ while i < len(batches) - 1:
774
+ logs.info("%s of %s vocab batches" % (str(i), str(len(batches))))
775
+ batch_result = np.sum(
776
+ document_matrix[batches[i] : batches[i + 1]].toarray(), axis=0
777
+ )
778
+ tf.append(batch_result)
779
+ i += 1
780
+ word_count_df = pd.DataFrame(
781
+ [np.sum(tf, axis=0)], columns=cvec.get_feature_names()
782
+ ).transpose()
783
+ # Now organize everything into the dataframes
784
+ word_count_df.columns = [CNT]
785
+ word_count_df.index.name = WORD
786
+ return word_count_df
787
+
788
+ def calc_p_word(word_count_df):
789
+ # p(word)
790
+ word_count_df[PROP] = word_count_df[CNT] / float(sum(word_count_df[CNT]))
791
+ vocab_counts_df = pd.DataFrame(word_count_df.sort_values(by=CNT, ascending=False))
792
+ vocab_counts_df[VOCAB] = vocab_counts_df.index
793
+ return vocab_counts_df
794
+
795
+
796
+ def filter_words(vocab_counts_df):
797
+ # TODO: Add warnings (which words are missing) to log file?
798
+ filtered_vocab_counts_df = vocab_counts_df.drop(_CLOSED_CLASS,
799
+ errors="ignore")
800
+ filtered_count = filtered_vocab_counts_df[CNT]
801
+ filtered_count_denom = float(sum(filtered_vocab_counts_df[CNT]))
802
+ filtered_vocab_counts_df[PROP] = filtered_count / filtered_count_denom
803
+ return filtered_vocab_counts_df
804
+
805
+
806
+ ## Figures ##
807
+
808
+
809
+ def make_fig_lengths(tokenized_df, length_field):
810
+ fig_tok_length = px.histogram(
811
+ tokenized_df, x=length_field, marginal="rug", hover_data=[length_field]
812
+ )
813
+ return fig_tok_length
814
+
815
+
816
+ def make_fig_labels(label_df, label_names, label_field):
817
+ labels = label_df[label_field].unique()
818
+ label_sums = [len(label_df[label_df[label_field] == label]) for label in labels]
819
+ fig_labels = px.pie(label_df, values=label_sums, names=label_names)
820
+ return fig_labels
821
+
822
+
823
+ def make_zipf_fig_ranked_word_list(vocab_df, unique_counts, unique_ranks):
824
+ ranked_words = {}
825
+ for count, rank in zip(unique_counts, unique_ranks):
826
+ vocab_df[vocab_df[CNT] == count]["rank"] = rank
827
+ ranked_words[rank] = ",".join(
828
+ vocab_df[vocab_df[CNT] == count].index.astype(str)
829
+ ) # Use the hovertext kw argument for hover text
830
+ ranked_words_list = [wrds for rank, wrds in sorted(ranked_words.items())]
831
+ return ranked_words_list
832
+
833
+
834
+ def make_npmi_fig(paired_results, subgroup_pair):
835
+ subgroup1, subgroup2 = subgroup_pair
836
+ UI_results = pd.DataFrame()
837
+ if "npmi-bias" in paired_results:
838
+ UI_results["npmi-bias"] = paired_results["npmi-bias"].astype(float)
839
+ UI_results[subgroup1 + "-npmi"] = paired_results["npmi"][
840
+ subgroup1 + "-npmi"
841
+ ].astype(float)
842
+ UI_results[subgroup1 + "-count"] = paired_results["count"][
843
+ subgroup1 + "-count"
844
+ ].astype(int)
845
+ if subgroup1 != subgroup2:
846
+ UI_results[subgroup2 + "-npmi"] = paired_results["npmi"][
847
+ subgroup2 + "-npmi"
848
+ ].astype(float)
849
+ UI_results[subgroup2 + "-count"] = paired_results["count"][
850
+ subgroup2 + "-count"
851
+ ].astype(int)
852
+ return UI_results.sort_values(by="npmi-bias", ascending=True)
853
+
854
+
855
+ def make_zipf_fig(vocab_counts_df, z):
856
+ zipf_counts = z.calc_zipf_counts(vocab_counts_df)
857
+ unique_counts = z.uniq_counts
858
+ unique_ranks = z.uniq_ranks
859
+ ranked_words_list = make_zipf_fig_ranked_word_list(
860
+ vocab_counts_df, unique_counts, unique_ranks
861
+ )
862
+ zmin = z.get_xmin()
863
+ logs.info("zipf counts is")
864
+ logs.info(zipf_counts)
865
+ layout = go.Layout(xaxis=dict(range=[0, 100]))
866
+ fig = go.Figure(
867
+ data=[
868
+ go.Bar(
869
+ x=z.uniq_ranks,
870
+ y=z.uniq_counts,
871
+ hovertext=ranked_words_list,
872
+ name="Word Rank Frequency",
873
+ )
874
+ ],
875
+ layout=layout,
876
+ )
877
+ fig.add_trace(
878
+ go.Scatter(
879
+ x=z.uniq_ranks[zmin : len(z.uniq_ranks)],
880
+ y=zipf_counts[zmin : len(z.uniq_ranks)],
881
+ hovertext=ranked_words_list[zmin : len(z.uniq_ranks)],
882
+ line=go.scatter.Line(color="crimson", width=3),
883
+ name="Zipf Predicted Frequency",
884
+ )
885
+ )
886
+ # Customize aspect
887
+ # fig.update_traces(marker_color='limegreen',
888
+ # marker_line_width=1.5, opacity=0.6)
889
+ fig.update_layout(title_text="Word Counts, Observed and Predicted by Zipf")
890
+ fig.update_layout(xaxis_title="Word Rank")
891
+ fig.update_layout(yaxis_title="Frequency")
892
+ fig.update_layout(legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.10))
893
+ return fig
894
+
895
+
896
+ ## Input/Output ###
897
+
898
+
899
+ def define_subgroup_files(subgroup_list, pmi_cache_path):
900
+ """
901
+ Sets the file ids for the input identity terms
902
+ :param subgroup_list: List of identity terms
903
+ :return:
904
+ """
905
+ subgroup_files = {}
906
+ for subgroup in subgroup_list:
907
+ # TODO: Should the pmi, npmi, and count just be one file?
908
+ subgroup_npmi_fid = pjoin(pmi_cache_path, subgroup + "_npmi.csv")
909
+ subgroup_pmi_fid = pjoin(pmi_cache_path, subgroup + "_pmi.csv")
910
+ subgroup_cooc_fid = pjoin(pmi_cache_path, subgroup + "_vocab_cooc.csv")
911
+ subgroup_files[subgroup] = (
912
+ subgroup_npmi_fid,
913
+ subgroup_pmi_fid,
914
+ subgroup_cooc_fid,
915
+ )
916
+ return subgroup_files
917
+
918
+
919
+ ## Input/Output ##
920
+
921
+
922
+ def intersect_dfs(df_dict):
923
+ started = 0
924
+ new_df = None
925
+ for key, df in df_dict.items():
926
+ if df is None:
927
+ continue
928
+ for key2, df2 in df_dict.items():
929
+ if df2 is None:
930
+ continue
931
+ if key == key2:
932
+ continue
933
+ if started:
934
+ new_df = new_df.join(df2, how="inner", lsuffix="1", rsuffix="2")
935
+ else:
936
+ new_df = df.join(df2, how="inner", lsuffix="1", rsuffix="2")
937
+ started = 1
938
+ return new_df.copy()
939
+
940
+
941
+ def write_df(df, df_fid):
942
+ feather.write_feather(df, df_fid)
943
+
944
+
945
+ def write_json(json_dict, json_fid):
946
+ with open(json_fid, "w", encoding="utf-8") as f:
947
+ json.dump(json_dict, f)
948
+
949
+
950
+ def write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files):
951
+ """
952
+ Saves the calculated nPMI statistics to their output files.
953
+ Includes the npmi scores for each identity term, the pmi scores, and the
954
+ co-occurrence counts of the identity term with all the other words
955
+ :param subgroup: Identity term
956
+ :return:
957
+ """
958
+ subgroup_fids = subgroup_files[subgroup]
959
+ subgroup_npmi_fid, subgroup_pmi_fid, subgroup_cooc_fid = subgroup_fids
960
+ subgroup_dfs = subgroup_dict[subgroup]
961
+ subgroup_cooc_df, subgroup_pmi_df, subgroup_npmi_df = subgroup_dfs
962
+ with open(subgroup_npmi_fid, "w+") as f:
963
+ subgroup_npmi_df.to_csv(f)
964
+ with open(subgroup_pmi_fid, "w+") as f:
965
+ subgroup_pmi_df.to_csv(f)
966
+ with open(subgroup_cooc_fid, "w+") as f:
967
+ subgroup_cooc_df.to_csv(f)
968
+
969
+
970
+ def write_zipf_data(z, zipf_fid):
971
+ zipf_dict = {}
972
+ zipf_dict["xmin"] = int(z.xmin)
973
+ zipf_dict["xmax"] = int(z.xmax)
974
+ zipf_dict["alpha"] = float(z.alpha)
975
+ zipf_dict["ks_distance"] = float(z.distance)
976
+ zipf_dict["p-value"] = float(z.ks_test.pvalue)
977
+ zipf_dict["uniq_counts"] = [int(count) for count in z.uniq_counts]
978
+ zipf_dict["uniq_ranks"] = [int(rank) for rank in z.uniq_ranks]
979
+ with open(zipf_fid, "w+", encoding="utf-8") as f:
980
+ json.dump(zipf_dict, f)
data_measurements/dataset_utils.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ from dataclasses import asdict
17
+ from os.path import exists
18
+
19
+ import pandas as pd
20
+ from datasets import Dataset, get_dataset_infos, load_dataset, load_from_disk
21
+
22
+ # treating inf values as NaN as well
23
+ pd.set_option("use_inf_as_na", True)
24
+
25
+ ## String names used in Hugging Face dataset configs.
26
+ HF_FEATURE_FIELD = "features"
27
+ HF_LABEL_FIELD = "label"
28
+ HF_DESC_FIELD = "description"
29
+
30
+ CACHE_DIR = "cache_dir"
31
+ ## String names we are using within this code.
32
+ # These are not coming from the stored dataset nor HF config,
33
+ # but rather used as identifiers in our dicts and dataframes.
34
+ OUR_TEXT_FIELD = "text"
35
+ OUR_LABEL_FIELD = "label"
36
+ TOKENIZED_FIELD = "tokenized_text"
37
+ EMBEDDING_FIELD = "embedding"
38
+ LENGTH_FIELD = "length"
39
+ VOCAB = "vocab"
40
+ WORD = "word"
41
+ CNT = "count"
42
+ PROP = "proportion"
43
+ TEXT_NAN_CNT = "text_nan_count"
44
+ TXT_LEN = "text lengths"
45
+ DEDUP_TOT = "dedup_total"
46
+
47
+ _DATASET_LIST = [
48
+ "c4",
49
+ "squad",
50
+ "squad_v2",
51
+ "hate_speech18",
52
+ "hate_speech_offensive",
53
+ "glue",
54
+ "super_glue",
55
+ "wikitext",
56
+ "imdb",
57
+ ]
58
+
59
+ _STREAMABLE_DATASET_LIST = [
60
+ "c4",
61
+ "wikitext",
62
+ ]
63
+
64
+ _MAX_ROWS = 200000
65
+
66
+
67
+ def load_truncated_dataset(
68
+ dataset_name,
69
+ config_name,
70
+ split_name,
71
+ num_rows=_MAX_ROWS,
72
+ cache_name=None,
73
+ use_cache=True,
74
+ use_streaming=True,
75
+ ):
76
+ """
77
+ This function loads the first `num_rows` items of a dataset for a
78
+ given `config_name` and `split_name`.
79
+ If `cache_name` exists, the truncated dataset is loaded from `cache_name`.
80
+ Otherwise, a new truncated dataset is created and immediately saved
81
+ to `cache_name`.
82
+ When the dataset is streamable, we iterate through the first
83
+ `num_rows` examples in streaming mode, write them to a jsonl file,
84
+ then create a new dataset from the json.
85
+ This is the most direct way to make a Dataset from an IterableDataset
86
+ as of datasets version 1.6.1.
87
+ Otherwise, we download the full dataset and select the first
88
+ `num_rows` items
89
+ Args:
90
+ dataset_name (string):
91
+ dataset id in the dataset library
92
+ config_name (string):
93
+ dataset configuration
94
+ split_name (string):
95
+ split name
96
+ num_rows (int):
97
+ number of rows to truncate the dataset to
98
+ cache_name (string):
99
+ name of the cache directory
100
+ use_cache (bool):
101
+ whether to load form the cache if it exists
102
+ use_streaming (bool):
103
+ whether to use streaming when the dataset supports it
104
+ Returns:
105
+ Dataset: the truncated dataset as a Dataset object
106
+ """
107
+ if cache_name is None:
108
+ cache_name = f"{dataset_name}_{config_name}_{split_name}_{num_rows}"
109
+ if exists(cache_name):
110
+ dataset = load_from_disk(cache_name)
111
+ else:
112
+ if use_streaming and dataset_name in _STREAMABLE_DATASET_LIST:
113
+ iterable_dataset = load_dataset(
114
+ dataset_name,
115
+ name=config_name,
116
+ split=split_name,
117
+ streaming=True,
118
+ ).take(num_rows)
119
+ rows = list(iterable_dataset)
120
+ f = open("temp.jsonl", "w", encoding="utf-8")
121
+ for row in rows:
122
+ _ = f.write(json.dumps(row) + "\n")
123
+ f.close()
124
+ dataset = Dataset.from_json(
125
+ "temp.jsonl", features=iterable_dataset.features, split=split_name
126
+ )
127
+ else:
128
+ full_dataset = load_dataset(
129
+ dataset_name,
130
+ name=config_name,
131
+ split=split_name,
132
+ )
133
+ dataset = full_dataset.select(range(num_rows))
134
+ dataset.save_to_disk(cache_name)
135
+ return dataset
136
+
137
+
138
+ def intersect_dfs(df_dict):
139
+ started = 0
140
+ new_df = None
141
+ for key, df in df_dict.items():
142
+ if df is None:
143
+ continue
144
+ for key2, df2 in df_dict.items():
145
+ if df2 is None:
146
+ continue
147
+ if key == key2:
148
+ continue
149
+ if started:
150
+ new_df = new_df.join(df2, how="inner", lsuffix="1", rsuffix="2")
151
+ else:
152
+ new_df = df.join(df2, how="inner", lsuffix="1", rsuffix="2")
153
+ started = 1
154
+ return new_df.copy()
155
+
156
+
157
+ def get_typed_features(features, ftype="string", parents=None):
158
+ """
159
+ Recursively get a list of all features of a certain dtype
160
+ :param features:
161
+ :param ftype:
162
+ :param parents:
163
+ :return: a list of tuples > e.g. ('A', 'B', 'C') for feature example['A']['B']['C']
164
+ """
165
+ if parents is None:
166
+ parents = []
167
+ typed_features = []
168
+ for name, feat in features.items():
169
+ if isinstance(feat, dict):
170
+ if feat.get("dtype", None) == ftype or feat.get("feature", {}).get(
171
+ ("dtype", None) == ftype
172
+ ):
173
+ typed_features += [tuple(parents + [name])]
174
+ elif "feature" in feat:
175
+ if feat["feature"].get("dtype", None) == ftype:
176
+ typed_features += [tuple(parents + [name])]
177
+ elif isinstance(feat["feature"], dict):
178
+ typed_features += get_typed_features(
179
+ feat["feature"], ftype, parents + [name]
180
+ )
181
+ else:
182
+ for k, v in feat.items():
183
+ if isinstance(v, dict):
184
+ typed_features += get_typed_features(
185
+ v, ftype, parents + [name, k]
186
+ )
187
+ elif name == "dtype" and feat == ftype:
188
+ typed_features += [tuple(parents)]
189
+ return typed_features
190
+
191
+
192
+ def get_label_features(features, parents=None):
193
+ """
194
+ Recursively get a list of all features that are ClassLabels
195
+ :param features:
196
+ :param parents:
197
+ :return: pairs of tuples as above and the list of class names
198
+ """
199
+ if parents is None:
200
+ parents = []
201
+ label_features = []
202
+ for name, feat in features.items():
203
+ if isinstance(feat, dict):
204
+ if "names" in feat:
205
+ label_features += [(tuple(parents + [name]), feat["names"])]
206
+ elif "feature" in feat:
207
+ if "names" in feat:
208
+ label_features += [
209
+ (tuple(parents + [name]), feat["feature"]["names"])
210
+ ]
211
+ elif isinstance(feat["feature"], dict):
212
+ label_features += get_label_features(
213
+ feat["feature"], parents + [name]
214
+ )
215
+ else:
216
+ for k, v in feat.items():
217
+ if isinstance(v, dict):
218
+ label_features += get_label_features(v, parents + [name, k])
219
+ elif name == "names":
220
+ label_features += [(tuple(parents), feat)]
221
+ return label_features
222
+
223
+
224
+ # get the info we need for the app sidebar in dict format
225
+ def dictionarize_info(dset_info):
226
+ info_dict = asdict(dset_info)
227
+ res = {
228
+ "config_name": info_dict["config_name"],
229
+ "splits": {
230
+ spl: spl_info["num_examples"]
231
+ for spl, spl_info in info_dict["splits"].items()
232
+ },
233
+ "features": {
234
+ "string": get_typed_features(info_dict["features"], "string"),
235
+ "int32": get_typed_features(info_dict["features"], "int32"),
236
+ "float32": get_typed_features(info_dict["features"], "float32"),
237
+ "label": get_label_features(info_dict["features"]),
238
+ },
239
+ "description": dset_info.description,
240
+ }
241
+ return res
242
+
243
+
244
+ def get_dataset_info_dicts(dataset_id=None):
245
+ """
246
+ Creates a dict from dataset configs.
247
+ Uses the datasets lib's get_dataset_infos
248
+ :return: Dictionary mapping dataset names to their configurations
249
+ """
250
+ if dataset_id != None:
251
+ ds_name_to_conf_dict = {
252
+ dataset_id: {
253
+ config_name: dictionarize_info(config_info)
254
+ for config_name, config_info in get_dataset_infos(dataset_id).items()
255
+ }
256
+ }
257
+ else:
258
+ ds_name_to_conf_dict = {
259
+ ds_id: {
260
+ config_name: dictionarize_info(config_info)
261
+ for config_name, config_info in get_dataset_infos(ds_id).items()
262
+ }
263
+ for ds_id in _DATASET_LIST
264
+ }
265
+ return ds_name_to_conf_dict
266
+
267
+
268
+ # get all instances of a specific field in a dataset
269
+ def extract_field(examples, field_path, new_field_name=None):
270
+ if new_field_name is None:
271
+ new_field_name = "_".join(field_path)
272
+ field_list = []
273
+ # TODO: Breaks the CLI if this isn't checked.
274
+ if isinstance(field_path, str):
275
+ field_path = [field_path]
276
+ item_list = examples[field_path[0]]
277
+ for field_name in field_path[1:]:
278
+ item_list = [
279
+ next_item
280
+ for item in item_list
281
+ for next_item in (
282
+ item[field_name]
283
+ if isinstance(item[field_name], list)
284
+ else [item[field_name]]
285
+ )
286
+ ]
287
+ field_list += [
288
+ field
289
+ for item in item_list
290
+ for field in (item if isinstance(item, list) else [item])
291
+ ]
292
+ return {new_field_name: field_list}
data_measurements/embeddings.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from os.path import exists
17
+ from os.path import join as pjoin
18
+
19
+ import plotly.graph_objects as go
20
+ import torch
21
+ import transformers
22
+ from datasets import load_from_disk
23
+ from tqdm import tqdm
24
+
25
+ from .dataset_utils import EMBEDDING_FIELD, OUR_TEXT_FIELD
26
+
27
+
28
+ def sentence_mean_pooling(model_output, attention_mask):
29
+ token_embeddings = model_output[
30
+ 0
31
+ ] # First element of model_output contains all token embeddings
32
+ input_mask_expanded = (
33
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
34
+ )
35
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
36
+ input_mask_expanded.sum(1), min=1e-9
37
+ )
38
+
39
+
40
+ class Embeddings:
41
+ def __init__(self, dstats, use_cache=False):
42
+ """Item embeddings and clustering"""
43
+ self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
44
+ self.node_list = None
45
+ self.nid_map = None
46
+ self.embeddings_dset = None
47
+ self.fig_tree = None
48
+ self.cached_clusters = {}
49
+ self.dstats = dstats
50
+ self.cache_path = dstats.cache_path
51
+ self.node_list_fid = pjoin(self.cache_path, "node_list.th")
52
+ self.use_cache = use_cache
53
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
54
+ "sentence-transformers/all-mpnet-base-v2"
55
+ )
56
+ self.model = transformers.AutoModel.from_pretrained(
57
+ "sentence-transformers/all-mpnet-base-v2"
58
+ ).to(self.device)
59
+
60
+ def make_text_embeddings(self):
61
+ embeddings_dset_fid = pjoin(self.cache_path, "embeddings_dset")
62
+ if self.use_cache and exists(embeddings_dset_fid):
63
+ self.embeddings_dset = load_from_disk(embeddings_dset_fid)
64
+ else:
65
+ self.embeddings_dset = self.make_embeddings()
66
+ self.embeddings_dset.save_to_disk(embeddings_dset_fid)
67
+
68
+ def make_hierarchical_clustering(self):
69
+ if self.use_cache and exists(self.node_list_fid):
70
+ self.node_list = torch.load(self.node_list_fid)
71
+ else:
72
+ self.make_text_embeddings()
73
+ self.node_list = self.fast_cluster(self.embeddings_dset, EMBEDDING_FIELD)
74
+ torch.save(self.node_list, self.node_list_fid)
75
+ self.nid_map = dict(
76
+ [(node["nid"], nid) for nid, node in enumerate(self.node_list)]
77
+ )
78
+ self.fig_tree = make_tree_plot(self.node_list, self.dstats.text_dset)
79
+
80
+ def compute_sentence_embeddings(self, sentences):
81
+ batch = self.tokenizer(
82
+ sentences, padding=True, truncation=True, return_tensors="pt"
83
+ )
84
+ batch = {k: v.to(self.device) for k, v in batch.items()}
85
+ with torch.no_grad():
86
+ model_output = self.model(**batch)
87
+ sentence_embeds = sentence_mean_pooling(
88
+ model_output, batch["attention_mask"]
89
+ )
90
+ sentence_embeds /= sentence_embeds.norm(dim=-1, keepdim=True)
91
+ return sentence_embeds
92
+
93
+ def make_embeddings(self):
94
+ def batch_embed_sentences(sentences):
95
+ return {
96
+ EMBEDDING_FIELD: [
97
+ embed.tolist()
98
+ for embed in self.compute_sentence_embeddings(
99
+ sentences[OUR_TEXT_FIELD]
100
+ )
101
+ ]
102
+ }
103
+
104
+ text_dset_embeds = self.dstats.text_dset.map(
105
+ batch_embed_sentences,
106
+ batched=True,
107
+ batch_size=32,
108
+ remove_columns=[self.dstats.our_text_field],
109
+ )
110
+
111
+ return text_dset_embeds
112
+
113
+ @staticmethod
114
+ def prepare_merges(embeddings, batch_size, low_thres=0.5):
115
+ top_idx_pre = torch.cat(
116
+ [torch.LongTensor(range(embeddings.shape[0]))[:, None]] * batch_size, dim=1
117
+ )
118
+ top_val_all = torch.Tensor(0, batch_size)
119
+ top_idx_all = torch.LongTensor(0, batch_size)
120
+ n_batches = math.ceil(len(embeddings) / batch_size)
121
+ for b in tqdm(range(n_batches)):
122
+ cos_scores = torch.mm(
123
+ embeddings[b * batch_size : (b + 1) * batch_size], embeddings.t()
124
+ )
125
+ for i in range(cos_scores.shape[0]):
126
+ cos_scores[i, (b * batch_size) + i :] = -1
127
+ top_val_large, top_idx_large = cos_scores.topk(
128
+ k=batch_size, dim=-1, largest=True
129
+ )
130
+ top_val_all = torch.cat([top_val_all, top_val_large], dim=0)
131
+ top_idx_all = torch.cat([top_idx_all, top_idx_large], dim=0)
132
+
133
+ all_merges = torch.cat(
134
+ [
135
+ top_idx_pre[top_val_all > low_thres][:, None],
136
+ top_idx_all[top_val_all > low_thres][:, None],
137
+ ],
138
+ dim=1,
139
+ )
140
+ all_merge_scores = top_val_all[top_val_all > low_thres]
141
+ return (all_merges, all_merge_scores)
142
+
143
+ @staticmethod
144
+ def merge_nodes(nodes, current_thres, previous_thres, all_merges, all_merge_scores):
145
+ merge_ids = (all_merge_scores <= previous_thres) * (
146
+ all_merge_scores > current_thres
147
+ )
148
+ merges = all_merges[merge_ids]
149
+ for a, b in merges.tolist():
150
+ node_a = nodes[a]
151
+ while node_a["parent_id"] != -1:
152
+ node_a = nodes[node_a["parent_id"]]
153
+ node_b = nodes[b]
154
+ while node_b["parent_id"] != -1:
155
+ node_b = nodes[node_b["parent_id"]]
156
+ if node_a["nid"] == node_b["nid"]:
157
+ continue
158
+ else:
159
+ # merge if threshold allows
160
+ if (node_a["depth"] + node_b["depth"]) > 0 and min(
161
+ node_a["merge_threshold"], node_b["merge_threshold"]
162
+ ) == current_thres:
163
+ merge_to = None
164
+ merge_from = None
165
+ if node_a["nid"] < node_b["nid"]:
166
+ merge_from = node_a
167
+ merge_to = node_b
168
+ if node_a["nid"] > node_b["nid"]:
169
+ merge_from = node_b
170
+ merge_to = node_a
171
+ merge_to["depth"] = max(merge_to["depth"], merge_from["depth"])
172
+ merge_to["weight"] += merge_from["weight"]
173
+ merge_to["children_ids"] += (
174
+ merge_from["children_ids"]
175
+ if merge_from["depth"] > 0
176
+ else [merge_from["nid"]]
177
+ )
178
+ for cid in merge_from["children_ids"]:
179
+ nodes[cid]["parent_id"] = merge_to["nid"]
180
+ merge_from["parent_id"] = merge_to["nid"]
181
+ # else new node
182
+ else:
183
+ new_nid = len(nodes)
184
+ new_node = {
185
+ "nid": new_nid,
186
+ "parent_id": -1,
187
+ "depth": max(node_a["depth"], node_b["depth"]) + 1,
188
+ "weight": node_a["weight"] + node_b["weight"],
189
+ "children": [],
190
+ "children_ids": [node_a["nid"], node_b["nid"]],
191
+ "example_ids": [],
192
+ "merge_threshold": current_thres,
193
+ }
194
+ node_a["parent_id"] = new_nid
195
+ node_b["parent_id"] = new_nid
196
+ nodes += [new_node]
197
+ return nodes
198
+
199
+ def finalize_node(self, node, nodes, min_cluster_size):
200
+ node["children"] = sorted(
201
+ [
202
+ self.finalize_node(nodes[cid], nodes, min_cluster_size)
203
+ for cid in node["children_ids"]
204
+ ],
205
+ key=lambda x: x["weight"],
206
+ reverse=True,
207
+ )
208
+ if node["depth"] > 0:
209
+ node["example_ids"] = [
210
+ eid for child in node["children"] for eid in child["example_ids"]
211
+ ]
212
+ node["children"] = [
213
+ child for child in node["children"] if child["weight"] >= min_cluster_size
214
+ ]
215
+ assert node["weight"] == len(node["example_ids"]), print(node)
216
+ return node
217
+
218
+ def fast_cluster(
219
+ self,
220
+ text_dset_embeds,
221
+ embedding_field,
222
+ batch_size=1000,
223
+ min_cluster_size=10,
224
+ low_thres=0.5,
225
+ ):
226
+ embeddings = torch.Tensor(text_dset_embeds[embedding_field])
227
+ batch_size = min(embeddings.shape[0], batch_size)
228
+ all_merges, all_merge_scores = self.prepare_merges(
229
+ embeddings, batch_size, low_thres
230
+ )
231
+ # prepare leaves
232
+ nodes = [
233
+ {
234
+ "nid": nid,
235
+ "parent_id": -1,
236
+ "depth": 0,
237
+ "weight": 1,
238
+ "children": [],
239
+ "children_ids": [],
240
+ "example_ids": [nid],
241
+ "merge_threshold": 1.0,
242
+ }
243
+ for nid in range(embeddings.shape[0])
244
+ ]
245
+ # one level per threshold range
246
+ for i in range(10):
247
+ p_thres = 1 - i * 0.05
248
+ c_thres = 0.95 - i * 0.05
249
+ nodes = self.merge_nodes(
250
+ nodes, c_thres, p_thres, all_merges, all_merge_scores
251
+ )
252
+ # make root
253
+ root_children = [
254
+ node
255
+ for node in nodes
256
+ if node["parent_id"] == -1 and node["weight"] >= min_cluster_size
257
+ ]
258
+ root = {
259
+ "nid": len(nodes),
260
+ "parent_id": -1,
261
+ "depth": max([node["depth"] for node in root_children]) + 1,
262
+ "weight": sum([node["weight"] for node in root_children]),
263
+ "children": [],
264
+ "children_ids": [node["nid"] for node in root_children],
265
+ "example_ids": [],
266
+ "merge_threshold": -1.0,
267
+ }
268
+ nodes += [root]
269
+ for node in root_children:
270
+ node["parent_id"] = root["nid"]
271
+ # finalize tree
272
+ tree = self.finalize_node(root, nodes, min_cluster_size)
273
+ node_list = []
274
+
275
+ def rec_map_nodes(node, node_list):
276
+ node_list += [node]
277
+ for child in node["children"]:
278
+ rec_map_nodes(child, node_list)
279
+
280
+ rec_map_nodes(tree, node_list)
281
+ # get centroids and distances
282
+ for node in node_list:
283
+ node_embeds = embeddings[node["example_ids"]]
284
+ node["centroid"] = node_embeds.sum(dim=0)
285
+ node["centroid"] /= node["centroid"].norm()
286
+ node["centroid_dot_prods"] = torch.mv(node_embeds, node["centroid"])
287
+ node["sorted_examples_centroid"] = sorted(
288
+ [
289
+ (eid, edp.item())
290
+ for eid, edp in zip(node["example_ids"], node["centroid_dot_prods"])
291
+ ],
292
+ key=lambda x: x[1],
293
+ reverse=True,
294
+ )
295
+ return node_list
296
+
297
+ def find_cluster_beam(self, sentence, beam_size=20):
298
+ """
299
+ This function finds the `beam_size` lef clusters that are closest to the
300
+ proposed sentence and returns the full path from the root to the cluster
301
+ along with the dot product between the sentence embedding and the
302
+ cluster centroid
303
+ Args:
304
+ sentence (string): input sentence for which to find clusters
305
+ beam_size (int): this is a beam size algorithm to explore the tree
306
+ Returns:
307
+ [([int], float)]: list of (path_from_root, score) sorted by score
308
+ """
309
+ embed = self.compute_sentence_embeddings([sentence])[0].to("cpu")
310
+ active_paths = [([0], torch.dot(embed, self.node_list[0]["centroid"]).item())]
311
+ finished_paths = []
312
+ children_ids_list = [
313
+ [
314
+ self.nid_map[nid]
315
+ for nid in self.node_list[path[-1]]["children_ids"]
316
+ if nid in self.nid_map
317
+ ]
318
+ for path, score in active_paths
319
+ ]
320
+ while len(active_paths) > 0:
321
+ next_ids = sorted(
322
+ [
323
+ (
324
+ beam_id,
325
+ nid,
326
+ torch.dot(embed, self.node_list[nid]["centroid"]).item(),
327
+ )
328
+ for beam_id, children_ids in enumerate(children_ids_list)
329
+ for nid in children_ids
330
+ ],
331
+ key=lambda x: x[2],
332
+ reverse=True,
333
+ )[:beam_size]
334
+ paths = [
335
+ (active_paths[beam_id][0] + [next_id], score)
336
+ for beam_id, next_id, score in next_ids
337
+ ]
338
+ active_paths = []
339
+ for path, score in paths:
340
+ if (
341
+ len(
342
+ [
343
+ nid
344
+ for nid in self.node_list[path[-1]]["children_ids"]
345
+ if nid in self.nid_map
346
+ ]
347
+ )
348
+ > 0
349
+ ):
350
+ active_paths += [(path, score)]
351
+ else:
352
+ finished_paths += [(path, score)]
353
+ children_ids_list = [
354
+ [
355
+ self.nid_map[nid]
356
+ for nid in self.node_list[path[-1]]["children_ids"]
357
+ if nid in self.nid_map
358
+ ]
359
+ for path, score in active_paths
360
+ ]
361
+ return sorted(
362
+ finished_paths,
363
+ key=lambda x: x[-1],
364
+ reverse=True,
365
+ )[:beam_size]
366
+
367
+
368
+ def make_tree_plot(node_list, text_dset):
369
+ nid_map = dict([(node["nid"], nid) for nid, node in enumerate(node_list)])
370
+
371
+ for nid, node in enumerate(node_list):
372
+ node["label"] = node.get(
373
+ "label",
374
+ f"{nid:2d} - {node['weight']:5d} items <br>"
375
+ + "<br>".join(
376
+ [
377
+ "> " + txt[:64] + ("..." if len(txt) >= 63 else "")
378
+ for txt in list(
379
+ set(text_dset.select(node["example_ids"])[OUR_TEXT_FIELD])
380
+ )[:5]
381
+ ]
382
+ ),
383
+ )
384
+
385
+ # make plot nodes
386
+ # TODO: something more efficient than set to remove duplicates
387
+ labels = [node["label"] for node in node_list]
388
+
389
+ root = node_list[0]
390
+ root["X"] = 0
391
+ root["Y"] = 0
392
+
393
+ def rec_make_coordinates(node):
394
+ total_weight = 0
395
+ add_weight = len(node["example_ids"]) - sum(
396
+ [child["weight"] for child in node["children"]]
397
+ )
398
+ for child in node["children"]:
399
+ child["X"] = node["X"] + total_weight
400
+ child["Y"] = node["Y"] - 1
401
+ total_weight += child["weight"] + add_weight / len(node["children"])
402
+ rec_make_coordinates(child)
403
+
404
+ rec_make_coordinates(root)
405
+
406
+ E = [] # list of edges
407
+ Xn = []
408
+ Yn = []
409
+ Xe = []
410
+ Ye = []
411
+ for nid, node in enumerate(node_list):
412
+ Xn += [node["X"]]
413
+ Yn += [node["Y"]]
414
+ for child in node["children"]:
415
+ E += [(nid, nid_map[child["nid"]])]
416
+ Xe += [node["X"], child["X"], None]
417
+ Ye += [node["Y"], child["Y"], None]
418
+
419
+ # make figure
420
+ fig = go.Figure()
421
+ fig.add_trace(
422
+ go.Scatter(
423
+ x=Xe,
424
+ y=Ye,
425
+ mode="lines",
426
+ line=dict(color="rgb(210,210,210)", width=1),
427
+ hoverinfo="none",
428
+ )
429
+ )
430
+ fig.add_trace(
431
+ go.Scatter(
432
+ x=Xn,
433
+ y=Yn,
434
+ mode="markers",
435
+ name="nodes",
436
+ marker=dict(
437
+ symbol="circle-dot",
438
+ size=18,
439
+ color="#6175c1",
440
+ line=dict(color="rgb(50,50,50)", width=1)
441
+ # '#DB4551',
442
+ ),
443
+ text=labels,
444
+ hoverinfo="text",
445
+ opacity=0.8,
446
+ )
447
+ )
448
+ return fig
data_measurements/npmi.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import warnings
17
+
18
+ import numpy as np
19
+ import pandas as pd
20
+ from sklearn.preprocessing import MultiLabelBinarizer
21
+
22
+ # Might be nice to print to log instead? Happens when we drop closed class.
23
+ warnings.filterwarnings(action="ignore", category=UserWarning)
24
+ # When we divide by 0 in log
25
+ np.seterr(divide="ignore")
26
+
27
+ # treating inf values as NaN as well
28
+ pd.set_option("use_inf_as_na", True)
29
+
30
+ logs = logging.getLogger(__name__)
31
+ logs.setLevel(logging.INFO)
32
+ logs.propagate = False
33
+
34
+ if not logs.handlers:
35
+
36
+ # Logging info to log file
37
+ file = logging.FileHandler("./log_files/npmi.log")
38
+ fileformat = logging.Formatter("%(asctime)s:%(message)s")
39
+ file.setLevel(logging.INFO)
40
+ file.setFormatter(fileformat)
41
+
42
+ # Logging debug messages to stream
43
+ stream = logging.StreamHandler()
44
+ streamformat = logging.Formatter("[data_measurements_tool] %(message)s")
45
+ stream.setLevel(logging.WARNING)
46
+ stream.setFormatter(streamformat)
47
+
48
+ logs.addHandler(file)
49
+ logs.addHandler(stream)
50
+
51
+ _NUM_BATCHES = 500
52
+
53
+
54
+ class nPMI:
55
+ # TODO: Expand beyond pairwise
56
+ def __init__(
57
+ self,
58
+ vocab_counts_df,
59
+ tokenized_df,
60
+ tokenized_col_name="tokenized_text",
61
+ num_batches=_NUM_BATCHES,
62
+ ):
63
+ logs.info("Initiating npmi class.")
64
+ logs.info("vocab is")
65
+ logs.info(vocab_counts_df)
66
+ self.vocab_counts_df = vocab_counts_df
67
+ logs.info("tokenized is")
68
+ self.tokenized_df = tokenized_df
69
+ logs.info(self.tokenized_df)
70
+ self.tokenized_col_name = tokenized_col_name
71
+ # self.mlb_list holds num batches x num_sentences
72
+ self.mlb_list = []
73
+
74
+ def binarize_words_in_sentence(self):
75
+ logs.info("Creating co-occurrence matrix for PMI calculations.")
76
+ batches = np.linspace(0, self.tokenized_df.shape[0], _NUM_BATCHES).astype(int)
77
+ i = 0
78
+ # Creates list of size (# batches x # sentences)
79
+ while i < len(batches) - 1:
80
+ # Makes a sparse matrix (shape: # sentences x # words),
81
+ # with the occurrence of each word per sentence.
82
+ mlb = MultiLabelBinarizer(classes=self.vocab_counts_df.index)
83
+ logs.info(
84
+ "%s of %s sentence binarize batches." % (str(i), str(len(batches)))
85
+ )
86
+ # Returns series: batch size x num_words
87
+ mlb_series = mlb.fit_transform(
88
+ self.tokenized_df[self.tokenized_col_name][batches[i] : batches[i + 1]]
89
+ )
90
+ i += 1
91
+ self.mlb_list.append(mlb_series)
92
+
93
+ def calc_cooccurrences(self, subgroup, subgroup_idx):
94
+ initialize = True
95
+ coo_df = None
96
+ # Big computation here! Should only happen once.
97
+ logs.info(
98
+ "Approaching big computation! Here, we binarize all words in the sentences, making a sparse matrix of sentences."
99
+ )
100
+ if not self.mlb_list:
101
+ self.binarize_words_in_sentence()
102
+ for batch_id in range(len(self.mlb_list)):
103
+ logs.info(
104
+ "%s of %s co-occurrence count batches"
105
+ % (str(batch_id), str(len(self.mlb_list)))
106
+ )
107
+ # List of all the sentences (list of vocab) in that batch
108
+ batch_sentence_row = self.mlb_list[batch_id]
109
+ # Dataframe of # sentences in batch x vocabulary size
110
+ sent_batch_df = pd.DataFrame(batch_sentence_row)
111
+ # logs.info('sent batch df is')
112
+ # logs.info(sent_batch_df)
113
+ # Subgroup counts per-sentence for the given batch
114
+ subgroup_df = sent_batch_df[subgroup_idx]
115
+ subgroup_df.columns = [subgroup]
116
+ # Remove the sentences where the count of the subgroup is 0.
117
+ # This way we have less computation & resources needs.
118
+ subgroup_df = subgroup_df[subgroup_df > 0]
119
+ logs.info("Removing 0 counts, subgroup_df is")
120
+ logs.info(subgroup_df)
121
+ mlb_subgroup_only = sent_batch_df[sent_batch_df[subgroup_idx] > 0]
122
+ logs.info("mlb subgroup only is")
123
+ logs.info(mlb_subgroup_only)
124
+ # Create cooccurrence matrix for the given subgroup and all words.
125
+ logs.info("Now we do the T.dot approach for co-occurrences")
126
+ batch_coo_df = pd.DataFrame(mlb_subgroup_only.T.dot(subgroup_df))
127
+
128
+ # Creates a batch-sized dataframe of co-occurrence counts.
129
+ # Note these could just be summed rather than be batch size.
130
+ if initialize:
131
+ coo_df = batch_coo_df
132
+ else:
133
+ coo_df = coo_df.add(batch_coo_df, fill_value=0)
134
+ logs.info("coo_df is")
135
+ logs.info(coo_df)
136
+ initialize = False
137
+ logs.info("Returning co-occurrence matrix")
138
+ logs.info(coo_df)
139
+ return pd.DataFrame(coo_df)
140
+
141
+ def calc_paired_metrics(self, subgroup_pair, subgroup_npmi_dict):
142
+ """
143
+ Calculates nPMI metrics between paired subgroups.
144
+ Special handling for a subgroup paired with itself.
145
+ :param subgroup_npmi_dict:
146
+ :return:
147
+ """
148
+ paired_results_dict = {"npmi": {}, "pmi": {}, "count": {}}
149
+ # Canonical ordering. This is done previously, but just in case...
150
+ subgroup1, subgroup2 = sorted(subgroup_pair)
151
+ vocab_cooc_df1, pmi_df1, npmi_df1 = subgroup_npmi_dict[subgroup1]
152
+ logs.info("vocab cooc")
153
+ logs.info(vocab_cooc_df1)
154
+ if subgroup1 == subgroup2:
155
+ shared_npmi_df = npmi_df1
156
+ shared_pmi_df = pmi_df1
157
+ shared_vocab_cooc_df = vocab_cooc_df1
158
+ else:
159
+ vocab_cooc_df2, pmi_df2, npmi_df2 = subgroup_npmi_dict[subgroup2]
160
+ logs.info("vocab cooc2")
161
+ logs.info(vocab_cooc_df2)
162
+ # Note that lsuffix and rsuffix should not come into play.
163
+ shared_npmi_df = npmi_df1.join(
164
+ npmi_df2, how="inner", lsuffix="1", rsuffix="2"
165
+ )
166
+ shared_pmi_df = pmi_df1.join(pmi_df2, how="inner", lsuffix="1", rsuffix="2")
167
+ shared_vocab_cooc_df = vocab_cooc_df1.join(
168
+ vocab_cooc_df2, how="inner", lsuffix="1", rsuffix="2"
169
+ )
170
+ shared_vocab_cooc_df = shared_vocab_cooc_df.dropna()
171
+ shared_vocab_cooc_df = shared_vocab_cooc_df[
172
+ shared_vocab_cooc_df.index.notnull()
173
+ ]
174
+ logs.info("shared npmi df")
175
+ logs.info(shared_npmi_df)
176
+ logs.info("shared vocab df")
177
+ logs.info(shared_vocab_cooc_df)
178
+ npmi_bias = (
179
+ shared_npmi_df[subgroup1 + "-npmi"] - shared_npmi_df[subgroup2 + "-npmi"]
180
+ )
181
+ paired_results_dict["npmi-bias"] = npmi_bias.dropna()
182
+ paired_results_dict["npmi"] = shared_npmi_df.dropna()
183
+ paired_results_dict["pmi"] = shared_pmi_df.dropna()
184
+ paired_results_dict["count"] = shared_vocab_cooc_df.dropna()
185
+ return paired_results_dict
186
+
187
+ def calc_metrics(self, subgroup):
188
+ # Index of the subgroup word in the sparse vector
189
+ subgroup_idx = self.vocab_counts_df.index.get_loc(subgroup)
190
+ logs.info("Calculating co-occurrences...")
191
+ df_coo = self.calc_cooccurrences(subgroup, subgroup_idx)
192
+ vocab_cooc_df = self.set_idx_cols(df_coo, subgroup)
193
+ logs.info(vocab_cooc_df)
194
+ logs.info("Calculating PMI...")
195
+ pmi_df = self.calc_PMI(vocab_cooc_df, subgroup)
196
+ logs.info(pmi_df)
197
+ logs.info("Calculating nPMI...")
198
+ npmi_df = self.calc_nPMI(pmi_df, vocab_cooc_df, subgroup)
199
+ logs.info(npmi_df)
200
+ return vocab_cooc_df, pmi_df, npmi_df
201
+
202
+ def set_idx_cols(self, df_coo, subgroup):
203
+ """
204
+ :param df_coo: Co-occurrence counts for subgroup, length is num_words
205
+ :return:
206
+ """
207
+ count_df = df_coo.set_index(self.vocab_counts_df.index)
208
+ count_df.columns = [subgroup + "-count"]
209
+ count_df[subgroup + "-count"] = count_df[subgroup + "-count"].astype(int)
210
+ return count_df
211
+
212
+ def calc_PMI(self, vocab_cooc_df, subgroup):
213
+ """
214
+ # PMI(x;y) = h(y) - h(y|x)
215
+ # = h(subgroup) - h(subgroup|word)
216
+ # = log (p(subgroup|word) / p(subgroup))
217
+ # nPMI additionally divides by -log(p(x,y)) = -log(p(x|y)p(y))
218
+ """
219
+ # Calculation of p(subgroup)
220
+ subgroup_prob = self.vocab_counts_df.loc[subgroup]["proportion"]
221
+ # Calculation of p(subgroup|word) = count(subgroup,word) / count(word)
222
+ # Because the inidices match (the vocab words),
223
+ # this division doesn't need to specify the index (I think?!)
224
+ p_subgroup_g_word = (
225
+ vocab_cooc_df[subgroup + "-count"] / self.vocab_counts_df["count"]
226
+ )
227
+ logs.info("p_subgroup_g_word is")
228
+ logs.info(p_subgroup_g_word)
229
+ pmi_df = pd.DataFrame()
230
+ pmi_df[subgroup + "-pmi"] = np.log(p_subgroup_g_word / subgroup_prob)
231
+ # Note: A potentially faster solution for adding count, npmi,
232
+ # can be based on this zip idea:
233
+ # df_test['size_kb'], df_test['size_mb'], df_test['size_gb'] =
234
+ # zip(*df_test['size'].apply(sizes))
235
+ return pmi_df.dropna()
236
+
237
+ def calc_nPMI(self, pmi_df, vocab_cooc_df, subgroup):
238
+ """
239
+ # nPMI additionally divides by -log(p(x,y)) = -log(p(x|y)p(y))
240
+ # = -log(p(word|subgroup)p(word))
241
+ """
242
+ p_word_g_subgroup = vocab_cooc_df[subgroup + "-count"] / sum(
243
+ vocab_cooc_df[subgroup + "-count"]
244
+ )
245
+ p_word = pmi_df.apply(
246
+ lambda x: self.vocab_counts_df.loc[x.name]["proportion"], axis=1
247
+ )
248
+ normalize_pmi = -np.log(p_word_g_subgroup * p_word)
249
+ npmi_df = pd.DataFrame()
250
+ npmi_df[subgroup + "-npmi"] = pmi_df[subgroup + "-pmi"] / normalize_pmi
251
+ return npmi_df.dropna()
data_measurements/streamlit_utils.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import statistics
16
+
17
+ import pandas as pd
18
+ import seaborn as sns
19
+ import streamlit as st
20
+ from st_aggrid import AgGrid, GridOptionsBuilder
21
+
22
+ from .dataset_utils import HF_DESC_FIELD, HF_FEATURE_FIELD, HF_LABEL_FIELD
23
+
24
+
25
+ def sidebar_header():
26
+ st.sidebar.markdown(
27
+ """
28
+ This demo showcases the [dataset metrics as we develop them](https://github.com/huggingface/DataMeasurements).
29
+ Right now this has:
30
+ - dynamic loading of datasets in the lib
31
+ - fetching config and info without downloading the dataset
32
+ - propose the list of candidate text and label features to select
33
+ We are still working on:
34
+ - implementing all the current tools
35
+ """,
36
+ unsafe_allow_html=True,
37
+ )
38
+
39
+
40
+ def sidebar_selection(ds_name_to_dict, column_id):
41
+ ds_names = list(ds_name_to_dict.keys())
42
+ with st.sidebar.expander(f"Choose dataset and field {column_id}", expanded=True):
43
+ # choose a dataset to analyze
44
+ ds_name = st.selectbox(
45
+ f"Choose dataset to explore{column_id}:",
46
+ ds_names,
47
+ index=ds_names.index("hate_speech18"),
48
+ )
49
+ # choose a config to analyze
50
+ ds_configs = ds_name_to_dict[ds_name]
51
+ config_names = list(ds_configs.keys())
52
+ config_name = st.selectbox(
53
+ f"Choose configuration{column_id}:",
54
+ config_names,
55
+ index=0,
56
+ )
57
+ # choose a subset of num_examples
58
+ # TODO: Handling for multiple text features
59
+ ds_config = ds_configs[config_name]
60
+ text_features = ds_config[HF_FEATURE_FIELD]["string"]
61
+ # TODO @yacine: Explain what this is doing and why eg tp[0] could = "id"
62
+ text_field = st.selectbox(
63
+ f"Which text feature from the{column_id} dataset would you like to analyze?",
64
+ [("text",)]
65
+ if ds_name == "c4"
66
+ else [tp for tp in text_features if tp[0] != "id"],
67
+ )
68
+ # Choose a split and dataset size
69
+ avail_splits = list(ds_config["splits"].keys())
70
+ # 12.Nov note: Removing "test" because those should not be examined
71
+ # without discussion of pros and cons, which we haven't done yet.
72
+ if "test" in avail_splits:
73
+ avail_splits.remove("test")
74
+ split = st.selectbox(
75
+ f"Which split from the{column_id} dataset would you like to analyze?",
76
+ avail_splits,
77
+ index=0,
78
+ )
79
+ label_field, label_names = (
80
+ ds_name_to_dict[ds_name][config_name][HF_FEATURE_FIELD][HF_LABEL_FIELD][0]
81
+ if len(
82
+ ds_name_to_dict[ds_name][config_name][HF_FEATURE_FIELD][HF_LABEL_FIELD]
83
+ )
84
+ > 0
85
+ else ((), [])
86
+ )
87
+ return {
88
+ "dset_name": ds_name,
89
+ "dset_config": config_name,
90
+ "split_name": split,
91
+ "text_field": text_field,
92
+ "label_field": label_field,
93
+ "label_names": label_names,
94
+ }
95
+
96
+
97
+ def expander_header(dstats, ds_name_to_dict, column_id):
98
+ with st.expander(f"Dataset Description{column_id}"):
99
+ st.markdown(
100
+ ds_name_to_dict[dstats.dset_name][dstats.dset_config][HF_DESC_FIELD]
101
+ )
102
+ st.dataframe(dstats.get_dataset_peek())
103
+
104
+
105
+ def expander_general_stats(dstats, top_n, column_id):
106
+ with st.expander(f"General Text Statistics{column_id}"):
107
+ st.caption(
108
+ "Use this widget to check whether the terms you see most represented in the dataset make sense for the goals of the dataset."
109
+ )
110
+ st.markdown(
111
+ "There are {0} total words".format(str(len(dstats.vocab_counts_df)))
112
+ )
113
+ st.markdown(
114
+ "There are {0} words after removing closed "
115
+ "class words".format(str(len(dstats.vocab_counts_filtered_df)))
116
+ )
117
+ sorted_top_vocab_df = dstats.vocab_counts_filtered_df.sort_values(
118
+ "count", ascending=False
119
+ ).head(top_n)
120
+ st.markdown(
121
+ "The most common [open class words](https://dictionary.apa.org/open-class-words) and their counts are: "
122
+ )
123
+ st.dataframe(sorted_top_vocab_df)
124
+ st.markdown(
125
+ "There are {0} missing values in the dataset.".format(
126
+ str(dstats.text_nan_count)
127
+ )
128
+ )
129
+ st.markdown(
130
+ "There are {0} duplicate items in the dataset. For more information about the duplicates, click the 'Duplicates' tab below.".format(
131
+ str(dstats.dedup_total)
132
+ )
133
+ )
134
+
135
+
136
+ ### Show the label distribution from the datasets
137
+ def expander_label_distribution(label_df, fig_labels, column_id):
138
+ with st.expander(f"Label Distribution{column_id}", expanded=False):
139
+ st.caption(
140
+ "Use this widget to see how balanced the labels in your dataset are."
141
+ )
142
+ if label_df is not None:
143
+ st.plotly_chart(fig_labels, use_container_width=True)
144
+ else:
145
+ st.markdown("No labels were found in the dataset")
146
+
147
+
148
+ def expander_text_lengths(
149
+ tokenized_df,
150
+ fig_tok_length,
151
+ avg_length,
152
+ std_length,
153
+ text_field_name,
154
+ length_field_name,
155
+ column_id,
156
+ ):
157
+ _TEXT_LENGTH_CAPTION = (
158
+ "Use this widget to identify outliers, particularly suspiciously long outliers."
159
+ )
160
+ with st.expander(f"Text Lengths{column_id}", expanded=False):
161
+ st.caption(_TEXT_LENGTH_CAPTION)
162
+ st.markdown(
163
+ "Below, you can see how the lengths of the text instances in your dataset are distributed."
164
+ )
165
+ st.markdown(
166
+ "Any unexpected peaks or valleys in the distribution may help to identify data instances you want to remove or augment."
167
+ )
168
+ st.markdown(
169
+ "### Here is the relative frequency of different text lengths in your dataset:"
170
+ )
171
+ st.plotly_chart(fig_tok_length, use_container_width=True)
172
+ data = tokenized_df[[length_field_name, text_field_name]].sort_values(
173
+ by=["length"], ascending=True
174
+ )
175
+ st.markdown(
176
+ "The average length of text instances is **"
177
+ + str(avg_length)
178
+ + " words**, with a standard deviation of **"
179
+ + str(std_length)
180
+ + "**."
181
+ )
182
+
183
+ start_id_show_lengths = st.slider(
184
+ f"Show the shortest sentences{column_id} starting at:",
185
+ 0,
186
+ len(data["length"].unique()),
187
+ value=0,
188
+ step=1,
189
+ )
190
+ st.dataframe(data[data["length"] == start_id_show_lengths].set_index("length"))
191
+
192
+
193
+ ### Third, use a sentence embedding model
194
+ def expander_text_embeddings(
195
+ text_dset, fig_tree, node_list, embeddings, text_field, column_id
196
+ ):
197
+ with st.expander(f"Text Embedding Clusters{column_id}", expanded=False):
198
+ _EMBEDDINGS_CAPTION = """
199
+ ### Hierarchical Clustering of Text Fields
200
+ Taking in the diversity of text represented in a dataset can be
201
+ challenging when it is made up of hundreds to thousands of sentences.
202
+ Grouping these text items based on a measure of similarity can help
203
+ users gain some insights into their distribution.
204
+ The following figure shows a hierarchical clustering of the text fields
205
+ in the dataset based on a
206
+ [Sentence-Transformer](https://hf.co/sentence-transformers/all-mpnet-base-v2)
207
+ model. Clusters are merged if any of the embeddings in cluster A has a
208
+ dot product with any of the embeddings or with the centroid of cluster B
209
+ higher than a threshold (one threshold per level, from 0.5 to 0.95).
210
+ To explore the clusters, you can:
211
+ - hover over a node to see the 5 most representative examples (deduplicated)
212
+ - enter an example in the text box below to see which clusters it is most similar to
213
+ - select a cluster by ID to show all of its examples
214
+ """
215
+ st.markdown(_EMBEDDINGS_CAPTION)
216
+ st.plotly_chart(fig_tree, use_container_width=True)
217
+ st.markdown("---\n")
218
+ if st.checkbox(
219
+ label="Enter text to see nearest clusters",
220
+ key=f"search_clusters_{column_id}",
221
+ ):
222
+ compare_example = st.text_area(
223
+ label="Enter some text here to see which of the clusters in the dataset it is closest to",
224
+ key=f"search_cluster_input_{column_id}",
225
+ )
226
+ if compare_example != "":
227
+ paths_to_leaves = embeddings.cached_clusters.get(
228
+ compare_example,
229
+ embeddings.find_cluster_beam(compare_example, beam_size=50),
230
+ )
231
+ clusters_intro = ""
232
+ if paths_to_leaves[0][1] < 0.3:
233
+ clusters_intro += (
234
+ "**Warning: no close clusters found (best score <0.3). **"
235
+ )
236
+ clusters_intro += "The closest clusters to the text entered aboce are:"
237
+ st.markdown(clusters_intro)
238
+ for path, score in paths_to_leaves[:5]:
239
+ example = text_dset[
240
+ node_list[path[-1]]["sorted_examples_centroid"][0][0]
241
+ ][text_field][:256]
242
+ st.write(
243
+ f"Cluster {path[-1]:5d} | Score: {score:.3f} \n Example: {example}"
244
+ )
245
+ show_node_default = paths_to_leaves[0][0][-1]
246
+ else:
247
+ show_node_default = len(node_list) // 2
248
+ else:
249
+ show_node_default = len(node_list) // 2
250
+ st.markdown("---\n")
251
+ show_node = st.selectbox(
252
+ f"Choose a leaf node to explore in the{column_id} dataset:",
253
+ range(len(node_list)),
254
+ index=show_node_default,
255
+ )
256
+ node = node_list[show_node]
257
+ start_id = st.slider(
258
+ f"Show closest sentences in cluster to the centroid{column_id} starting at index:",
259
+ 0,
260
+ len(node["sorted_examples_centroid"]) - 5,
261
+ value=0,
262
+ step=5,
263
+ )
264
+ for sid, sim in node["sorted_examples_centroid"][start_id : start_id + 5]:
265
+ # only show the first 4 lines and the first 10000 characters
266
+ show_text = text_dset[sid][text_field][:10000]
267
+ show_text = "\n".join(show_text.split("\n")[:4])
268
+ st.text(f"{sim:.3f} \t {show_text}")
269
+
270
+
271
+ ### Then, show duplicates
272
+ def expander_text_duplicates(dedup_df, column_id):
273
+ with st.expander(f"Text Duplicates{column_id}", expanded=False):
274
+ st.caption(
275
+ "Use this widget to identify text strings that appear more than once."
276
+ )
277
+ st.markdown(
278
+ "A model's training and testing may be negatively affected by unwarranted duplicates ([Lee et al., 2021](https://arxiv.org/abs/2107.06499))."
279
+ )
280
+ dedup_df["count"] = dedup_df["count"] + 1
281
+ st.markdown("------")
282
+ st.write(
283
+ "### Here is the list of all the duplicated items and their counts in your dataset:"
284
+ )
285
+ # Eh...adding 1 because otherwise it looks too weird for duplicate counts when the value is just 1.
286
+ if len(dedup_df) == 0:
287
+ st.write("There are no duplicates in this dataset! 🥳")
288
+ else:
289
+ gb = GridOptionsBuilder.from_dataframe(dedup_df)
290
+ gb.configure_column(
291
+ f"text{column_id}",
292
+ wrapText=True,
293
+ resizable=True,
294
+ autoHeight=True,
295
+ min_column_width=85,
296
+ use_container_width=True,
297
+ )
298
+ go = gb.build()
299
+ AgGrid(dedup_df, gridOptions=go)
300
+
301
+
302
+ def expander_npmi_description(min_vocab):
303
+ _NPMI_CAPTION = (
304
+ "Use this widget to identify problematic biases and stereotypes in your data."
305
+ )
306
+ _NPMI_CAPTION1 = """
307
+ nPMI scores for a word help to identify potentially
308
+ problematic associations, ranked by how close the association is."""
309
+ _NPMI_CAPTION2 = """
310
+ nPMI bias scores for paired words help to identify how word
311
+ associations are skewed between the selected selected words
312
+ ([Aka et al., 2021](https://arxiv.org/abs/2103.03417)).
313
+ """
314
+
315
+ st.caption(_NPMI_CAPTION)
316
+ st.markdown(_NPMI_CAPTION1)
317
+ st.markdown(_NPMI_CAPTION2)
318
+ st.markdown(" ")
319
+ st.markdown(
320
+ "You can select from gender and sexual orientation "
321
+ "identity terms that appear in the dataset at least %s "
322
+ "times." % min_vocab
323
+ )
324
+ st.markdown(
325
+ "The resulting ranked words are those that co-occur with both "
326
+ "identity terms. "
327
+ )
328
+ st.markdown(
329
+ "The more *positive* the score, the more associated the word is with the first identity term. "
330
+ "The more *negative* the score, the more associated the word is with the second identity term."
331
+ )
332
+
333
+
334
+ ### Finally, show Zipf stuff
335
+ def expander_zipf(z, zipf_fig, column_id):
336
+ _ZIPF_CAPTION = """This shows how close the observed language is to an ideal
337
+ natural language distribution following [Zipf's law](https://en.wikipedia.org/wiki/Zipf%27s_law),
338
+ calculated by minimizing the [Kolmogorov-Smirnov (KS) statistic](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test)."""
339
+
340
+ powerlaw_eq = r"""p(x) \propto x^{- \alpha}"""
341
+ zipf_summary = (
342
+ "The optimal alpha based on this dataset is: **"
343
+ + str(round(z.alpha, 2))
344
+ + "**, with a KS distance of: **"
345
+ + str(round(z.distance, 2))
346
+ )
347
+ zipf_summary += (
348
+ "**. This was fit with a minimum rank value of: **"
349
+ + str(int(z.xmin))
350
+ + "**, which is the optimal rank *beyond which* the scaling regime of the power law fits best."
351
+ )
352
+
353
+ alpha_warning = "Your alpha value is a bit on the high side, which means that the distribution over words in this dataset is a bit unnatural. This could be due to non-language items throughout the dataset."
354
+ xmin_warning = "The minimum rank for this fit is a bit on the high side, which means that the frequencies of your most common words aren't distributed as would be expected by Zipf's law."
355
+ fit_results_table = pd.DataFrame.from_dict(
356
+ {
357
+ r"Alpha:": [str("%.2f" % z.alpha)],
358
+ "KS distance:": [str("%.2f" % z.distance)],
359
+ "Min rank:": [str("%s" % int(z.xmin))],
360
+ },
361
+ columns=["Results"],
362
+ orient="index",
363
+ )
364
+ fit_results_table.index.name = column_id
365
+ with st.expander(
366
+ f"Vocabulary Distribution{column_id}: Zipf's Law Fit", expanded=False
367
+ ):
368
+ st.caption(
369
+ "Use this widget for the counts of different words in your dataset, measuring the difference between the observed count and the expected count under Zipf's law."
370
+ )
371
+ st.markdown(_ZIPF_CAPTION)
372
+ st.write(
373
+ """
374
+ A Zipfian distribution follows the power law: $p(x) \propto x^{-α}$
375
+ with an ideal α value of 1."""
376
+ )
377
+ st.markdown(
378
+ "In general, an alpha greater than 2 or a minimum rank greater than 10 (take with a grain of salt) means that your distribution is relativaly _unnatural_ for natural language. This can be a sign of mixed artefacts in the dataset, such as HTML markup."
379
+ )
380
+ st.markdown(
381
+ "Below, you can see the counts of each word in your dataset vs. the expected number of counts following a Zipfian distribution."
382
+ )
383
+ st.markdown("-----")
384
+ st.write("### Here is your dataset's Zipf results:")
385
+ st.dataframe(fit_results_table)
386
+ st.write(zipf_summary)
387
+ # TODO: Nice UI version of the content in the comments.
388
+ # st.markdown("\nThe KS test p-value is < %.2f" % z.ks_test.pvalue)
389
+ # if z.ks_test.pvalue < 0.01:
390
+ # st.markdown(
391
+ # "\n Great news! Your data fits a powerlaw with a minimum KS " "distance of %.4f" % z.distance)
392
+ # else:
393
+ # st.markdown("\n Sadly, your data does not fit a powerlaw. =(")
394
+ # st.markdown("Checking the goodness of fit of our observed distribution")
395
+ # st.markdown("to the hypothesized power law distribution")
396
+ # st.markdown("using a Kolmogorov–Smirnov (KS) test.")
397
+ st.plotly_chart(zipf_fig, use_container_width=True)
398
+ if z.alpha > 2:
399
+ st.markdown(alpha_warning)
400
+ if z.xmin > 5:
401
+ st.markdown(xmin_warning)
402
+
403
+
404
+ ### Finally finally finally, show nPMI stuff.
405
+ def npmi_widget(column_id, available_terms, npmi_stats, min_vocab, use_cache=False):
406
+ """
407
+ Part of the main app, but uses a user interaction so pulled out as its own f'n.
408
+ :param use_cache:
409
+ :param column_id:
410
+ :param npmi_stats:
411
+ :param min_vocab:
412
+ :return:
413
+ """
414
+ with st.expander(f"Word Association{column_id}: nPMI", expanded=False):
415
+ if len(available_terms) > 0:
416
+ expander_npmi_description(min_vocab)
417
+ st.markdown("-----")
418
+ term1 = st.selectbox(
419
+ f"What is the first term you want to select?{column_id}",
420
+ available_terms,
421
+ )
422
+ term2 = st.selectbox(
423
+ f"What is the second term you want to select?{column_id}",
424
+ reversed(available_terms),
425
+ )
426
+ # We calculate/grab nPMI data based on a canonical (alphabetic)
427
+ # subgroup ordering.
428
+ subgroup_pair = sorted([term1, term2])
429
+ try:
430
+ joint_npmi_df = npmi_stats.load_or_prepare_joint_npmi(subgroup_pair)
431
+ npmi_show(joint_npmi_df)
432
+ except KeyError:
433
+ st.markdown(
434
+ "**WARNING!** The nPMI for these terms has not been pre-computed, please re-run caching."
435
+ )
436
+ else:
437
+ st.markdown(
438
+ "No words found co-occurring with both of the selected identity terms."
439
+ )
440
+
441
+
442
+ def npmi_show(paired_results):
443
+ if paired_results.empty:
444
+ st.markdown("No words that co-occur enough times for results! Or there's a 🐛.")
445
+ else:
446
+ s = pd.DataFrame(paired_results.sort_values(by="npmi-bias", ascending=True))
447
+ # s.columns=pd.MultiIndex.from_arrays([['npmi','npmi','npmi','count', 'count'],['bias','man','straight','man','straight']])
448
+ s.index.name = "word"
449
+ npmi_cols = s.filter(like="npmi").columns
450
+ count_cols = s.filter(like="count").columns
451
+ # TODO: This is very different look than the duplicates table above. Should probably standardize.
452
+ cm = sns.palplot(sns.diverging_palette(270, 36, s=99, l=48, n=16))
453
+ out_df = (
454
+ s.style.background_gradient(subset=npmi_cols, cmap=cm)
455
+ .format(subset=npmi_cols, formatter="{:,.3f}")
456
+ .format(subset=count_cols, formatter=int)
457
+ .set_properties(
458
+ subset=count_cols, **{"width": "10em", "text-align": "center"}
459
+ )
460
+ .set_properties(**{"align": "center"})
461
+ .set_caption(
462
+ "nPMI scores and co-occurence counts between the selected identity terms and the words they both co-occur with"
463
+ )
464
+ ) # s = pd.read_excel("output.xlsx", index_col="word")
465
+ st.write("### Here is your dataset's nPMI results:")
466
+ st.dataframe(out_df)
467
+
468
+
469
+ ### Dumping unused functions here for now
470
+ ### Second, show the distribution of text perplexities
471
+ def expander_text_perplexities(text_label_df, sorted_sents_loss, fig_loss):
472
+ with st.expander("Show text perplexities A", expanded=False):
473
+ st.markdown("### Text perplexities A")
474
+ st.plotly_chart(fig_loss, use_container_width=True)
475
+ start_id_show_loss = st.slider(
476
+ "Show highest perplexity sentences in A starting at index:",
477
+ 0,
478
+ text_label_df.shape[0] - 5,
479
+ value=0,
480
+ step=5,
481
+ )
482
+ for lss, sent in sorted_sents_loss[start_id_show_loss : start_id_show_loss + 5]:
483
+ st.text(f"{lss:.3f} {sent}")
data_measurements/zipf.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+
17
+ import numpy as np
18
+ import pandas as pd
19
+ import powerlaw
20
+ import streamlit as st
21
+ from scipy.stats import ks_2samp
22
+ from scipy.stats import zipf as zipf_lib
23
+
24
+ from .dataset_utils import CNT, PROP
25
+
26
+ # treating inf values as NaN as well
27
+
28
+ pd.set_option("use_inf_as_na", True)
29
+
30
+ logs = logging.getLogger(__name__)
31
+ logs.setLevel(logging.INFO)
32
+ logs.propagate = False
33
+
34
+ if not logs.handlers:
35
+
36
+ # Logging info to log file
37
+ file = logging.FileHandler("./log_files/zipf.log")
38
+ fileformat = logging.Formatter("%(asctime)s:%(message)s")
39
+ file.setLevel(logging.INFO)
40
+ file.setFormatter(fileformat)
41
+
42
+ # Logging debug messages to stream
43
+ stream = logging.StreamHandler()
44
+ streamformat = logging.Formatter("[data_measurements_tool] %(message)s")
45
+ stream.setLevel(logging.WARNING)
46
+ stream.setFormatter(streamformat)
47
+
48
+ logs.addHandler(file)
49
+ logs.addHandler(stream)
50
+
51
+
52
+ class Zipf:
53
+ def __init__(self, vocab_counts_df=pd.DataFrame()):
54
+ self.vocab_counts_df = vocab_counts_df
55
+ self.alpha = None
56
+ self.xmin = None
57
+ self.xmax = None
58
+ self.fit = None
59
+ self.ranked_words = {}
60
+ self.uniq_counts = []
61
+ self.uniq_ranks = []
62
+ self.uniq_fit_counts = None
63
+ self.term_df = None
64
+ self.pvalue = None
65
+ self.ks_test = None
66
+ self.distance = None
67
+ self.fit = None
68
+ self.predicted_zipf_counts = None
69
+ if not self.vocab_counts_df.empty:
70
+ logs.info("Fitting based on input vocab counts.")
71
+ self.calc_fit(vocab_counts_df)
72
+ logs.info("Getting predicted counts.")
73
+ self.predicted_zipf_counts = self.calc_zipf_counts(vocab_counts_df)
74
+
75
+ def load(self, zipf_dict):
76
+ self.set_xmin(zipf_dict["xmin"])
77
+ self.set_xmax(zipf_dict["xmax"])
78
+ self.set_alpha(zipf_dict["alpha"])
79
+ self.set_ks_distance(zipf_dict["ks_distance"])
80
+ self.set_p(zipf_dict["p-value"])
81
+ self.set_unique_ranks(zipf_dict["uniq_ranks"])
82
+ self.set_unique_counts(zipf_dict["uniq_counts"])
83
+
84
+ def calc_fit(self, vocab_counts_df):
85
+ """
86
+ Uses the powerlaw package to fit the observed frequencies to a zipfian distribution.
87
+ We use the KS-distance to fit, as that seems more appropriate that MLE.
88
+ :param vocab_counts_df:
89
+ :return:
90
+ """
91
+ self.vocab_counts_df = vocab_counts_df
92
+ # TODO: These proportions may have already been calculated.
93
+ vocab_counts_df[PROP] = vocab_counts_df[CNT] / float(sum(vocab_counts_df[CNT]))
94
+ rank_column = vocab_counts_df[CNT].rank(
95
+ method="dense", numeric_only=True, ascending=False
96
+ )
97
+ vocab_counts_df["rank"] = rank_column.astype("int64")
98
+ observed_counts = vocab_counts_df[CNT].values
99
+ # Note another method for determining alpha might be defined by
100
+ # (Newman, 2005): alpha = 1 + n * sum(ln( xi / xmin )) ^ -1
101
+ self.fit = powerlaw.Fit(observed_counts, fit_method="KS", discrete=True)
102
+ # This should probably be a pmf (not pdf); using discrete=True above.
103
+ # original_data=False uses only the fitted data (within xmin and xmax).
104
+ # pdf_bin_edges: The portion of the data within the bin.
105
+ # observed_pdf: The probability density function (normalized histogram)
106
+ # of the data.
107
+ pdf_bin_edges, observed_pdf = self.fit.pdf(original_data=False)
108
+ # See the 'Distribution' class described here for info:
109
+ # https://pythonhosted.org/powerlaw/#powerlaw.Fit.pdf
110
+ theoretical_distro = self.fit.power_law
111
+ # The probability density function (normalized histogram) of the
112
+ # theoretical distribution.
113
+ predicted_pdf = theoretical_distro.pdf()
114
+ # !!!! CRITICAL VALUE FOR ZIPF !!!!
115
+ self.alpha = theoretical_distro.alpha
116
+ # Exclusive xmin: The optimal xmin *beyond which* the scaling regime of
117
+ # the power law fits best.
118
+ self.xmin = theoretical_distro.xmin
119
+ self.xmax = theoretical_distro.xmax
120
+ self.distance = theoretical_distro.KS()
121
+ self.ks_test = ks_2samp(observed_pdf, predicted_pdf)
122
+ self.pvalue = self.ks_test[1]
123
+ logs.info("KS test:")
124
+ logs.info(self.ks_test)
125
+
126
+ def set_xmax(self, xmax):
127
+ """
128
+ xmax is usually None, so we add some handling to set it as the
129
+ maximum rank in the dataset.
130
+ :param xmax:
131
+ :return:
132
+ """
133
+ if xmax:
134
+ self.xmax = int(xmax)
135
+ elif self.uniq_counts:
136
+ self.xmax = int(len(self.uniq_counts))
137
+ elif self.uniq_ranks:
138
+ self.xmax = int(len(self.uniq_ranks))
139
+
140
+ def get_xmax(self):
141
+ """
142
+ :return:
143
+ """
144
+ if not self.xmax:
145
+ self.set_xmax(self.xmax)
146
+ return self.xmax
147
+
148
+ def set_p(self, p):
149
+ self.p = int(p)
150
+
151
+ def get_p(self):
152
+ return int(self.p)
153
+
154
+ def set_xmin(self, xmin):
155
+ self.xmin = xmin
156
+
157
+ def get_xmin(self):
158
+ if self.xmin:
159
+ return int(self.xmin)
160
+ return self.xmin
161
+
162
+ def set_alpha(self, alpha):
163
+ self.alpha = float(alpha)
164
+
165
+ def get_alpha(self):
166
+ return float(self.alpha)
167
+
168
+ def set_ks_distance(self, distance):
169
+ self.distance = float(distance)
170
+
171
+ def get_ks_distance(self):
172
+ return self.distance
173
+
174
+ def calc_zipf_counts(self, vocab_counts_df):
175
+ """
176
+ The fit is based on an optimal xmin (minimum rank)
177
+ Let's use this to make count estimates for the zipf fit,
178
+ by multiplying the fitted pmf value by the sum of counts above xmin.
179
+ :return: array of count values following the fitted pmf.
180
+ """
181
+ # TODO: Limit from above xmin to below xmax, not just above xmin.
182
+ counts = vocab_counts_df[CNT]
183
+ self.uniq_counts = list(pd.unique(counts))
184
+ self.uniq_ranks = list(np.arange(1, len(self.uniq_counts) + 1))
185
+ logs.info(self.uniq_counts)
186
+ logs.info(self.xmin)
187
+ logs.info(self.xmax)
188
+ # Makes sure they are ints if not None
189
+ xmin = self.get_xmin()
190
+ xmax = self.get_xmax()
191
+ self.uniq_fit_counts = self.uniq_counts[xmin + 1 : xmax]
192
+ pmf_mass = float(sum(self.uniq_fit_counts))
193
+ zipf_counts = np.array(
194
+ [self.estimate_count(rank, pmf_mass) for rank in self.uniq_ranks]
195
+ )
196
+ return zipf_counts
197
+
198
+ def estimate_count(self, rank, pmf_mass):
199
+ return int(round(zipf_lib.pmf(rank, self.alpha) * pmf_mass))
200
+
201
+ def set_unique_ranks(self, ranks):
202
+ self.uniq_ranks = ranks
203
+
204
+ def get_unique_ranks(self):
205
+ return self.uniq_ranks
206
+
207
+ def get_unique_fit_counts(self):
208
+ return self.uniq_fit_counts
209
+
210
+ def set_unique_counts(self, counts):
211
+ self.uniq_counts = counts
212
+
213
+ def get_unique_counts(self):
214
+ return self.uniq_counts
215
+
216
+ def set_axes(self, unique_counts, unique_ranks):
217
+ self.uniq_counts = unique_counts
218
+ self.uniq_ranks = unique_ranks
219
+
220
+ # TODO: Incorporate this function (not currently using)
221
+ def fit_others(self, fit):
222
+ st.markdown(
223
+ "_Checking log likelihood ratio to see if the data is better explained by other well-behaved distributions..._"
224
+ )
225
+ # The first value returned from distribution_compare is the log likelihood ratio
226
+ better_distro = False
227
+ trunc = fit.distribution_compare("power_law", "truncated_power_law")
228
+ if trunc[0] < 0:
229
+ st.markdown("Seems a truncated power law is a better fit.")
230
+ better_distro = True
231
+
232
+ lognormal = fit.distribution_compare("power_law", "lognormal")
233
+ if lognormal[0] < 0:
234
+ st.markdown("Seems a lognormal distribution is a better fit.")
235
+ st.markdown("But don't panic -- that happens sometimes with language.")
236
+ better_distro = True
237
+
238
+ exponential = fit.distribution_compare("power_law", "exponential")
239
+ if exponential[0] < 0:
240
+ st.markdown("Seems an exponential distribution is a better fit. Panic.")
241
+ better_distro = True
242
+
243
+ if not better_distro:
244
+ st.markdown("\nSeems your data is best fit by a power law. Celebrate!!")