meg-huggingface committed on
Commit
2981bb2
1 Parent(s): 2947ee2

Be gone, you merge conflicting file

git rm data_measurements/dataset_statistics.py

data_measurements/dataset_statistics.py DELETED
@@ -1,1313 +0,0 @@
1
- # Copyright 2021 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import json
16
- import logging
17
- import statistics
18
- import torch
19
- from os import mkdir
20
- from os.path import exists, isdir
21
- from os.path import join as pjoin
22
-
23
- import nltk
24
- import numpy as np
25
- import pandas as pd
26
- import plotly
27
- import plotly.express as px
28
- import plotly.figure_factory as ff
29
- import plotly.graph_objects as go
30
- import pyarrow.feather as feather
31
- import matplotlib.pyplot as plt
32
- import matplotlib.image as mpimg
33
- import seaborn as sns
34
- from datasets import load_from_disk
35
- from nltk.corpus import stopwords
36
- from sklearn.feature_extraction.text import CountVectorizer
37
-
38
- from .dataset_utils import (
39
- TOT_WORDS,
40
- TOT_OPEN_WORDS,
41
- CNT,
42
- DEDUP_TOT,
43
- EMBEDDING_FIELD,
44
- LENGTH_FIELD,
45
- OUR_LABEL_FIELD,
46
- OUR_TEXT_FIELD,
47
- PROP,
48
- TEXT_NAN_CNT,
49
- TOKENIZED_FIELD,
50
- TXT_LEN,
51
- VOCAB,
52
- WORD,
53
- extract_field,
54
- load_truncated_dataset,
55
- )
56
- from .embeddings import Embeddings
57
- from .npmi import nPMI
58
- from .zipf import Zipf
59
-
60
- pd.options.display.float_format = "{:,.3f}".format
61
-
62
- logs = logging.getLogger(__name__)
63
- logs.setLevel(logging.WARNING)
64
- logs.propagate = False
65
-
66
- if not logs.handlers:
67
-
68
- # Logging info to log file
69
- file = logging.FileHandler("./log_files/dataset_statistics.log")
70
- fileformat = logging.Formatter("%(asctime)s:%(message)s")
71
- file.setLevel(logging.INFO)
72
- file.setFormatter(fileformat)
73
-
74
- # Logging debug messages to stream
75
- stream = logging.StreamHandler()
76
- streamformat = logging.Formatter("[data_measurements_tool] %(message)s")
77
- stream.setLevel(logging.WARNING)
78
- stream.setFormatter(streamformat)
79
-
80
- logs.addHandler(file)
81
- logs.addHandler(stream)
82
-
83
-
84
- # TODO: Read this in depending on chosen language / expand beyond english
85
- nltk.download("stopwords")
86
- _CLOSED_CLASS = (
87
- stopwords.words("english")
88
- + [
89
- "t",
90
- "n",
91
- "ll",
92
- "d",
93
- "wasn",
94
- "weren",
95
- "won",
96
- "aren",
97
- "wouldn",
98
- "shouldn",
99
- "didn",
100
- "don",
101
- "hasn",
102
- "ain",
103
- "couldn",
104
- "doesn",
105
- "hadn",
106
- "haven",
107
- "isn",
108
- "mightn",
109
- "mustn",
110
- "needn",
111
- "shan",
112
- "would",
113
- "could",
114
- "dont",
115
- "u",
116
- ]
117
- + [str(i) for i in range(0, 21)]
118
- )
119
- _IDENTITY_TERMS = [
120
- "man",
121
- "woman",
122
- "non-binary",
123
- "gay",
124
- "lesbian",
125
- "queer",
126
- "trans",
127
- "straight",
128
- "cis",
129
- "she",
130
- "her",
131
- "hers",
132
- "he",
133
- "him",
134
- "his",
135
- "they",
136
- "them",
137
- "their",
138
- "theirs",
139
- "himself",
140
- "herself",
141
- ]
142
- # treating inf values as NaN as well
143
- pd.set_option("use_inf_as_na", True)
144
-
145
- _MIN_VOCAB_COUNT = 10
146
- _TREE_DEPTH = 12
147
- _TREE_MIN_NODES = 250
148
- # as long as we're using sklearn - already pushing the resources
149
- _MAX_CLUSTER_EXAMPLES = 5000
150
- _NUM_VOCAB_BATCHES = 2000
151
- _TOP_N = 100
152
- _CVEC = CountVectorizer(token_pattern="(?u)\\b\\w+\\b", lowercase=True)
153
-
154
- class DatasetStatisticsCacheClass:
155
- def __init__(
156
- self,
157
- cache_dir,
158
- dset_name,
159
- dset_config,
160
- split_name,
161
- text_field,
162
- label_field,
163
- label_names,
164
- calculation=None,
165
- use_cache=False,
166
- ):
167
- # This is only used for standalone runs for each kind of measurement.
168
- self.calculation = calculation
169
- self.our_text_field = OUR_TEXT_FIELD
170
- self.our_length_field = LENGTH_FIELD
171
- self.our_label_field = OUR_LABEL_FIELD
172
- self.our_tokenized_field = TOKENIZED_FIELD
173
- self.our_embedding_field = EMBEDDING_FIELD
174
- self.cache_dir = cache_dir
175
- # Use stored data if there; otherwise calculate afresh
176
- self.use_cache = use_cache
177
- ### What are we analyzing?
178
- # name of the Hugging Face dataset
179
- self.dset_name = dset_name
180
- # name of the dataset config
181
- self.dset_config = dset_config
182
- # name of the split to analyze
183
- self.split_name = split_name
184
- # TODO: Should this be "feature"?
185
- # which text fields are we analysing?
186
- self.text_field = text_field
187
- # which label fields are we analysing?
188
- self.label_field = label_field
189
- # what are the names of the classes?
190
- self.label_names = label_names
191
- ## Hugging Face dataset objects
192
- self.dset = None # original dataset
193
- # HF dataset with all of the self.text_field instances in self.dset
194
- self.text_dset = None
195
- self.dset_peek = None
196
- # HF dataset with text embeddings in the same order as self.text_dset
197
- self.embeddings_dset = None
198
- # HF dataset with all of the self.label_field instances in self.dset
199
- self.label_dset = None
200
- ## Data frames
201
- # Tokenized text
202
- self.tokenized_df = None
203
- # save sentence length histogram in the class so it doesn't get re-computed
204
- self.length_df = None
205
- self.fig_tok_length = None
206
- # Data Frame version of self.label_dset
207
- self.label_df = None
208
- # save label pie chart in the class so it doesn't get re-computed
209
- self.fig_labels = None
210
- # Vocabulary with word counts in the dataset
211
- self.vocab_counts_df = None
212
- # Vocabulary filtered to remove stopwords
213
- self.vocab_counts_filtered_df = None
214
- self.sorted_top_vocab_df = None
215
- ## General statistics and duplicates
216
- self.total_words = 0
217
- self.total_open_words = 0
218
- # Number of NaN values (NOT empty strings)
219
- self.text_nan_count = 0
220
- # Number of text items that appear more than once in the dataset
221
- self.dedup_total = 0
222
- # Duplicated text items along with their number of occurrences ("count")
223
- self.dup_counts_df = None
224
- self.avg_length = None
225
- self.std_length = None
226
- self.general_stats_dict = None
227
- self.num_uniq_lengths = 0
228
- # clustering text by embeddings
229
- # the hierarchical clustering tree is represented as a list of nodes,
230
- # the first is the root
231
- self.node_list = []
232
- # save tree figure in the class so it doesn't get re-computed
233
- self.fig_tree = None
234
- # keep Embeddings object around to explore clusters
235
- self.embeddings = None
236
- # nPMI
237
- # Holds a nPMIStatisticsCacheClass object
238
- self.npmi_stats = None
239
- # TODO: Have lowercase be an option for a user to set.
240
- self.to_lowercase = True
241
- # The minimum number of times a word must occur to be included in
242
- # word-count-based calculations (currently just relevant to nPMI)
243
- self.min_vocab_count = _MIN_VOCAB_COUNT
244
- # zipf
245
- self.z = None
246
- self.zipf_fig = None
247
- self.cvec = _CVEC
248
- # File definitions
249
- # path to the directory used for caching
250
- if not isinstance(text_field, str):
251
- text_field = "-".join(text_field)
252
- #if isinstance(label_field, str):
253
- # label_field = label_field
254
- #else:
255
- # label_field = "-".join(label_field)
256
- self.cache_path = pjoin(
257
- self.cache_dir,
258
- f"{dset_name}_{dset_config}_{split_name}_{text_field}", #{label_field},
259
- )
260
- if not isdir(self.cache_path):
261
- logs.warning("Creating cache directory %s." % self.cache_path)
262
- mkdir(self.cache_path)
263
-
264
- # Cache files not needed for UI
265
- self.dset_fid = pjoin(self.cache_path, "base_dset")
266
- self.tokenized_df_fid = pjoin(self.cache_path, "tokenized_df.feather")
267
- self.label_dset_fid = pjoin(self.cache_path, "label_dset")
268
-
269
- # Needed for UI -- embeddings
270
- self.text_dset_fid = pjoin(self.cache_path, "text_dset")
271
- # Needed for UI
272
- self.dset_peek_json_fid = pjoin(self.cache_path, "dset_peek.json")
273
-
274
- ## Label cache files.
275
- # Needed for UI
276
- self.fig_labels_json_fid = pjoin(self.cache_path, "fig_labels.json")
277
-
278
- ## Length cache files
279
- # Needed for UI
280
- self.length_df_fid = pjoin(self.cache_path, "length_df.feather")
281
- # Needed for UI
282
- self.length_stats_json_fid = pjoin(self.cache_path, "length_stats.json")
283
- self.vocab_counts_df_fid = pjoin(self.cache_path, "vocab_counts.feather")
284
- # Needed for UI
285
- self.dup_counts_df_fid = pjoin(self.cache_path, "dup_counts_df.feather")
286
- # Needed for UI
287
- self.fig_tok_length_fid = pjoin(self.cache_path, "fig_tok_length.json")
288
-
289
- ## General text stats
290
- # Needed for UI
291
- self.general_stats_json_fid = pjoin(self.cache_path, "general_stats_dict.json")
292
- # Needed for UI
293
- self.sorted_top_vocab_df_fid = pjoin(self.cache_path,
294
- "sorted_top_vocab.feather")
295
- ## Zipf cache files
296
- # Needed for UI
297
- self.zipf_fid = pjoin(self.cache_path, "zipf_basic_stats.json")
298
- # Needed for UI
299
- self.zipf_fig_fid = pjoin(self.cache_path, "zipf_fig.json")
300
-
301
- ## Embeddings cache files
302
- # Needed for UI
303
- self.node_list_fid = pjoin(self.cache_path, "node_list.th")
304
- # Needed for UI
305
- self.fig_tree_json_fid = pjoin(self.cache_path, "fig_tree.json")
306
- self.zipf_counts = None
307
-
308
- self.live = False
309
-
310
- def set_deployment(self, live=True):
311
- """
312
- Function that we can hit when we deploy, so that cache files are not
313
- written out/recalculated, but instead that part of the UI can be punted.
314
- """
315
- self.live = live
316
-
317
- def get_base_dataset(self):
318
- """Gets a pointer to the truncated base dataset object."""
319
- if not self.dset:
320
- self.dset = load_truncated_dataset(
321
- self.dset_name,
322
- self.dset_config,
323
- self.split_name,
324
- cache_name=self.dset_fid,
325
- use_cache=True,
326
- use_streaming=True,
327
- )
328
-
329
- def load_or_prepare_general_stats(self, save=True):
330
- """
331
- Content for expander_general_stats widget.
332
- Provides statistics for total words, total open words,
333
- the sorted top vocab, the NaN count, and the duplicate count.
334
- Args:
335
-
336
- Returns:
337
-
338
- """
339
- # General statistics
340
- if (
341
- self.use_cache
342
- and exists(self.general_stats_json_fid)
343
- and exists(self.dup_counts_df_fid)
344
- and exists(self.sorted_top_vocab_df_fid)
345
- ):
346
- logs.info('Loading cached general stats')
347
- self.load_general_stats()
348
- else:
349
- if not self.live:
350
- logs.info('Preparing general stats')
351
- self.prepare_general_stats()
352
- if save:
353
- write_df(self.sorted_top_vocab_df, self.sorted_top_vocab_df_fid)
354
- write_df(self.dup_counts_df, self.dup_counts_df_fid)
355
- write_json(self.general_stats_dict, self.general_stats_json_fid)
356
-
357
-
358
- def load_or_prepare_text_lengths(self, save=True):
359
- """
360
- The text length widget relies on this function, which provides
361
- a figure of the text lengths, some text length statistics, and
362
- a text length dataframe to peruse.
363
- Args:
364
- save:
365
- Returns:
366
-
367
- """
368
- # Text length figure
369
- if (self.use_cache and exists(self.fig_tok_length_fid)):
370
- self.fig_tok_length_png = mpimg.imread(self.fig_tok_length_fid)
371
- self.fig_tok_length = read_plotly(self.fig_tok_length_fid)
372
- else:
373
- if not self.live:
374
- self.prepare_fig_text_lengths()
375
- if save:
376
- write_plotly(self.fig_tok_length, self.fig_tok_length_fid)
377
-
378
- # Text length dataframe
379
- if self.use_cache and exists(self.length_df_fid):
380
- self.length_df = feather.read_feather(self.length_df_fid)
381
- else:
382
- if not self.live:
383
- self.prepare_length_df()
384
- if save:
385
- write_df(self.length_df, self.length_df_fid)
386
-
387
- # Text length stats.
388
- if self.use_cache and exists(self.length_stats_json_fid):
389
- with open(self.length_stats_json_fid, "r") as f:
390
- self.length_stats_dict = json.load(f)
391
- self.avg_length = self.length_stats_dict["avg length"]
392
- self.std_length = self.length_stats_dict["std length"]
393
- self.num_uniq_lengths = self.length_stats_dict["num lengths"]
394
- else:
395
- if not self.live:
396
- self.prepare_text_length_stats()
397
- if save:
398
- write_json(self.length_stats_dict, self.length_stats_json_fid)
399
-
400
- def prepare_length_df(self):
401
- if not self.live:
402
- if self.tokenized_df is None:
403
- self.tokenized_df = self.do_tokenization()
404
- self.tokenized_df[LENGTH_FIELD] = self.tokenized_df[
405
- TOKENIZED_FIELD].apply(len)
406
- self.length_df = self.tokenized_df[
407
- [LENGTH_FIELD, OUR_TEXT_FIELD]].sort_values(
408
- by=[LENGTH_FIELD], ascending=True
409
- )
410
-
411
- def prepare_text_length_stats(self):
412
- if not self.live:
413
- if self.tokenized_df is None or LENGTH_FIELD not in self.tokenized_df.columns or self.length_df is None:
414
- self.prepare_length_df()
415
- avg_length = sum(self.tokenized_df[LENGTH_FIELD])/len(self.tokenized_df[LENGTH_FIELD])
416
- self.avg_length = round(avg_length, 1)
417
- std_length = statistics.stdev(self.tokenized_df[LENGTH_FIELD])
418
- self.std_length = round(std_length, 1)
419
- self.num_uniq_lengths = len(self.length_df["length"].unique())
420
- self.length_stats_dict = {"avg length": self.avg_length,
421
- "std length": self.std_length,
422
- "num lengths": self.num_uniq_lengths}
423
-
424
- def prepare_fig_text_lengths(self):
425
- if not self.live:
426
- if self.tokenized_df is None or LENGTH_FIELD not in self.tokenized_df.columns:
427
- self.prepare_length_df()
428
- self.fig_tok_length = make_fig_lengths(self.tokenized_df, LENGTH_FIELD)
429
-
430
- def load_or_prepare_embeddings(self, save=True):
431
- if self.use_cache and exists(self.node_list_fid) and exists(self.fig_tree_json_fid):
432
- self.node_list = torch.load(self.node_list_fid)
433
- self.fig_tree = read_plotly(self.fig_tree_json_fid)
434
- elif self.use_cache and exists(self.node_list_fid):
435
- self.node_list = torch.load(self.node_list_fid)
436
- self.fig_tree = make_tree_plot(self.node_list,
437
- self.text_dset)
438
- if save:
439
- write_plotly(self.fig_tree, self.fig_tree_json_fid)
440
- else:
441
- self.embeddings = Embeddings(self, use_cache=self.use_cache)
442
- self.embeddings.make_hierarchical_clustering()
443
- self.node_list = self.embeddings.node_list
444
- self.fig_tree = make_tree_plot(self.node_list,
445
- self.text_dset)
446
- if save:
447
- torch.save(self.node_list, self.node_list_fid)
448
- write_plotly(self.fig_tree, self.fig_tree_json_fid)
449
-
450
- # get vocab with word counts
451
- def load_or_prepare_vocab(self, save=True):
452
- """
453
- Calculates the vocabulary count from the tokenized text.
454
- The resulting dataframes may be used in nPMI calculations, zipf, etc.
455
- :param
456
- :return:
457
- """
458
- if (
459
- self.use_cache
460
- and exists(self.vocab_counts_df_fid)
461
- ):
462
- logs.info("Reading vocab from cache")
463
- self.load_vocab()
464
- self.vocab_counts_filtered_df = filter_vocab(self.vocab_counts_df)
465
- else:
466
- logs.info("Calculating vocab afresh")
467
- if self.tokenized_df is None or len(self.tokenized_df) == 0:
468
- self.tokenized_df = self.do_tokenization()
469
- if save:
470
- logs.info("Writing out.")
471
- write_df(self.tokenized_df, self.tokenized_df_fid)
472
- word_count_df = count_vocab_frequencies(self.tokenized_df)
473
- logs.info("Making dfs with proportion.")
474
- self.vocab_counts_df = calc_p_word(word_count_df)
475
- self.vocab_counts_filtered_df = filter_vocab(self.vocab_counts_df)
476
- if save:
477
- logs.info("Writing out.")
478
- write_df(self.vocab_counts_df, self.vocab_counts_df_fid)
479
- logs.info("unfiltered vocab")
480
- logs.info(self.vocab_counts_df)
481
- logs.info("filtered vocab")
482
- logs.info(self.vocab_counts_filtered_df)
483
-
484
- def load_vocab(self):
485
- with open(self.vocab_counts_df_fid, "rb") as f:
486
- self.vocab_counts_df = feather.read_feather(f)
487
- # Handling for changes in how the index is saved.
488
- self.vocab_counts_df = self._set_idx_col_names(self.vocab_counts_df)
489
-
490
- def load_or_prepare_text_duplicates(self, save=True):
491
- if self.use_cache and exists(self.dup_counts_df_fid):
492
- with open(self.dup_counts_df_fid, "rb") as f:
493
- self.dup_counts_df = feather.read_feather(f)
494
- elif self.dup_counts_df is None:
495
- if not self.live:
496
- self.prepare_text_duplicates()
497
- if save:
498
- write_df(self.dup_counts_df, self.dup_counts_df_fid)
499
- else:
500
- if not self.live:
501
- # This happens when self.dup_counts_df is already defined,
502
- # i.e. when general statistics were calculated first,
503
- # since general statistics require the number of duplicates.
504
- if save:
505
- write_df(self.dup_counts_df, self.dup_counts_df_fid)
506
-
507
- def load_general_stats(self):
508
- self.general_stats_dict = json.load(open(self.general_stats_json_fid, encoding="utf-8"))
509
- with open(self.sorted_top_vocab_df_fid, "rb") as f:
510
- self.sorted_top_vocab_df = feather.read_feather(f)
511
- self.text_nan_count = self.general_stats_dict[TEXT_NAN_CNT]
512
- self.dedup_total = self.general_stats_dict[DEDUP_TOT]
513
- self.total_words = self.general_stats_dict[TOT_WORDS]
514
- self.total_open_words = self.general_stats_dict[TOT_OPEN_WORDS]
515
-
516
- def prepare_general_stats(self):
517
- if not self.live:
518
- if self.tokenized_df is None:
519
- logs.warning("Tokenized dataset not yet loaded; doing so.")
520
- self.load_or_prepare_dataset()
521
- if self.vocab_counts_df is None:
522
- logs.warning("Vocab not yet loaded; doing so.")
523
- self.load_or_prepare_vocab()
524
- self.sorted_top_vocab_df = self.vocab_counts_filtered_df.sort_values(
525
- "count", ascending=False
526
- ).head(_TOP_N)
527
- self.total_words = len(self.vocab_counts_df)
528
- self.total_open_words = len(self.vocab_counts_filtered_df)
529
- self.text_nan_count = int(self.tokenized_df.isnull().sum().sum())
530
- self.prepare_text_duplicates()
531
- self.dedup_total = sum(self.dup_counts_df[CNT])
532
- self.general_stats_dict = {
533
- TOT_WORDS: self.total_words,
534
- TOT_OPEN_WORDS: self.total_open_words,
535
- TEXT_NAN_CNT: self.text_nan_count,
536
- DEDUP_TOT: self.dedup_total,
537
- }
538
-
539
- def prepare_text_duplicates(self):
540
- if not self.live:
541
- if self.tokenized_df is None:
542
- self.load_or_prepare_tokenized_df()
543
- dup_df = self.tokenized_df[
544
- self.tokenized_df.duplicated([OUR_TEXT_FIELD])]
545
- self.dup_counts_df = pd.DataFrame(
546
- dup_df.pivot_table(
547
- columns=[OUR_TEXT_FIELD], aggfunc="size"
548
- ).sort_values(ascending=False),
549
- columns=[CNT],
550
- )
551
- self.dup_counts_df[OUR_TEXT_FIELD] = self.dup_counts_df.index.copy()
552
-
553
- def load_or_prepare_dataset(self, save=True):
554
- """
555
- Prepares the HF datasets and data frames containing the untokenized and
556
- tokenized text as well as the label values.
557
- self.tokenized_df is used further for calculating text lengths,
558
- word counts, etc.
559
- Args:
560
- save: Store the calculated data to disk.
561
-
562
- Returns:
563
-
564
- """
565
- logs.info("Doing text dset.")
566
- self.load_or_prepare_text_dset(save)
567
- logs.info("Doing tokenized dataframe")
568
- self.load_or_prepare_tokenized_df(save)
569
- logs.info("Doing dataset peek")
570
- self.load_or_prepare_dset_peek(save)
571
-
572
- def load_or_prepare_dset_peek(self, save=True):
573
- if self.use_cache and exists(self.dset_peek_json_fid):
574
- with open(self.dset_peek_json_fid, "r") as f:
575
- self.dset_peek = json.load(f)["dset peek"]
576
- else:
577
- if self.dset is None:
578
- self.get_base_dataset()
579
- self.dset_peek = self.dset[:100]
580
- if save:
581
- write_json({"dset peek": self.dset_peek}, self.dset_peek_json_fid)
582
-
583
- def load_or_prepare_tokenized_df(self, save=True):
584
- if (self.use_cache and exists(self.tokenized_df_fid)):
585
- self.tokenized_df = feather.read_feather(self.tokenized_df_fid)
586
- else:
587
- if not self.live:
588
- # tokenize all text instances
589
- self.tokenized_df = self.do_tokenization()
590
- if save:
591
- logs.warning("Saving tokenized dataset to disk")
592
- # save tokenized text
593
- write_df(self.tokenized_df, self.tokenized_df_fid)
594
-
595
- def load_or_prepare_text_dset(self, save=True):
596
- if (self.use_cache and exists(self.text_dset_fid)):
597
- # load extracted text
598
- self.text_dset = load_from_disk(self.text_dset_fid)
599
- logs.warning("Loaded dataset from disk")
600
- logs.info(self.text_dset)
601
- # ...Or load it from the server and store it anew
602
- else:
603
- if not self.live:
604
- self.prepare_text_dset()
605
- if save:
606
- # save extracted text instances
607
- logs.warning("Saving dataset to disk")
608
- self.text_dset.save_to_disk(self.text_dset_fid)
609
-
610
- def prepare_text_dset(self):
611
- if not self.live:
612
- self.get_base_dataset()
613
- # extract all text instances
614
- self.text_dset = self.dset.map(
615
- lambda examples: extract_field(
616
- examples, self.text_field, OUR_TEXT_FIELD
617
- ),
618
- batched=True,
619
- remove_columns=list(self.dset.features),
620
- )
621
-
622
- def do_tokenization(self):
623
- """
624
- Tokenizes the dataset
625
- :return:
626
- """
627
- if self.text_dset is None:
628
- self.load_or_prepare_text_dset()
629
- sent_tokenizer = self.cvec.build_tokenizer()
630
-
631
- def tokenize_batch(examples):
632
- # TODO: lowercase should be an option
633
- res = {
634
- TOKENIZED_FIELD: [
635
- tuple(sent_tokenizer(text.lower()))
636
- for text in examples[OUR_TEXT_FIELD]
637
- ]
638
- }
639
- res[LENGTH_FIELD] = [len(tok_text) for tok_text in res[TOKENIZED_FIELD]]
640
- return res
641
-
642
- tokenized_dset = self.text_dset.map(
643
- tokenize_batch,
644
- batched=True,
645
- # remove_columns=[OUR_TEXT_FIELD], keep around to print
646
- )
647
- tokenized_df = pd.DataFrame(tokenized_dset)
648
- return tokenized_df
649
-
650
- def set_label_field(self, label_field="label"):
651
- """
652
- Setter for label_field. Used in the CLI when a user asks for information
653
- about labels, but does not specify the field;
654
- 'label' is assumed as a default.
655
- """
656
- self.label_field = label_field
657
-
658
- def load_or_prepare_labels(self, save=True):
659
- # TODO: This is in a transitory state for creating fig cache.
660
- # Clean up to be caching and reading everything correctly.
661
- """
662
- Extracts labels from the Dataset
663
- :return:
664
- """
665
- # extracted labels
666
- if len(self.label_field) > 0:
667
- if self.use_cache and exists(self.fig_labels_json_fid):
668
- self.fig_labels = read_plotly(self.fig_labels_json_fid)
669
- elif self.use_cache and exists(self.label_dset_fid):
670
- # load extracted labels
671
- self.label_dset = load_from_disk(self.label_dset_fid)
672
- self.label_df = self.label_dset.to_pandas()
673
- self.fig_labels = make_fig_labels(
674
- self.label_df, self.label_names, OUR_LABEL_FIELD
675
- )
676
- if save:
677
- write_plotly(self.fig_labels, self.fig_labels_json_fid)
678
- else:
679
- if not self.live:
680
- self.prepare_labels()
681
- if save:
682
- # save extracted label instances
683
- self.label_dset.save_to_disk(self.label_dset_fid)
684
- write_plotly(self.fig_labels, self.fig_labels_json_fid)
685
-
686
- def prepare_labels(self):
687
- if not self.live:
688
- self.get_base_dataset()
689
- self.label_dset = self.dset.map(
690
- lambda examples: extract_field(
691
- examples, self.label_field, OUR_LABEL_FIELD
692
- ),
693
- batched=True,
694
- remove_columns=list(self.dset.features),
695
- )
696
- self.label_df = self.label_dset.to_pandas()
697
- self.fig_labels = make_fig_labels(
698
- self.label_df, self.label_names, OUR_LABEL_FIELD
699
- )
700
-
701
- def load_or_prepare_npmi(self):
702
- self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=self.use_cache)
703
- self.npmi_stats.load_or_prepare_npmi_terms()
704
-
705
- def load_or_prepare_zipf(self, save=True):
706
- # TODO: Current UI only uses the fig, meaning the self.z here is irrelevant
707
- # when only reading from cache. Either the UI should use it, or it should
708
- # be removed when reading in cache
709
- if self.use_cache and exists(self.zipf_fig_fid) and exists(self.zipf_fid):
710
- with open(self.zipf_fid, "r") as f:
711
- zipf_dict = json.load(f)
712
- self.z = Zipf()
713
- self.z.load(zipf_dict)
714
- # TODO: Should this be cached?
715
- self.zipf_counts = self.z.calc_zipf_counts(self.vocab_counts_df)
716
- self.zipf_fig = read_plotly(self.zipf_fig_fid)
717
- elif self.use_cache and exists(self.zipf_fid):
718
- # TODO: Read zipf data so that the vocab is there.
719
- with open(self.zipf_fid, "r") as f:
720
- zipf_dict = json.load(f)
721
- self.z = Zipf()
722
- self.z.load(zipf_dict)
723
- self.zipf_fig = make_zipf_fig(self.vocab_counts_df, self.z)
724
- if save:
725
- write_plotly(self.zipf_fig, self.zipf_fig_fid)
726
- else:
727
- self.z = Zipf(self.vocab_counts_df)
728
- self.zipf_fig = make_zipf_fig(self.vocab_counts_df, self.z)
729
- if save:
730
- write_zipf_data(self.z, self.zipf_fid)
731
- write_plotly(self.zipf_fig, self.zipf_fig_fid)
732
-
733
- def _set_idx_col_names(self, input_vocab_df):
734
- if input_vocab_df.index.name != VOCAB and VOCAB in input_vocab_df.columns:
735
- input_vocab_df = input_vocab_df.set_index([VOCAB])
736
- input_vocab_df[VOCAB] = input_vocab_df.index
737
- return input_vocab_df
738
-
739
-
740
- class nPMIStatisticsCacheClass:
741
- """ "Class to interface between the app and the nPMI class
742
- by calling the nPMI class with the user's selections."""
743
-
744
- def __init__(self, dataset_stats, use_cache=False):
745
- self.live = dataset_stats.live
746
- self.dstats = dataset_stats
747
- self.pmi_cache_path = pjoin(self.dstats.cache_path, "pmi_files")
748
- if not isdir(self.pmi_cache_path):
749
- logs.warning("Creating pmi cache directory %s." % self.pmi_cache_path)
750
- # We need to preprocess everything.
751
- mkdir(self.pmi_cache_path)
752
- self.joint_npmi_df_dict = {}
753
- # TODO: Users ideally can type in whatever words they want.
754
- self.termlist = _IDENTITY_TERMS
755
- # termlist terms that occur at least _MIN_VOCAB_COUNT times
756
- self.available_terms = _IDENTITY_TERMS
757
- logs.info(self.termlist)
758
- self.use_cache = use_cache
759
- # TODO: Let users specify
760
- self.open_class_only = True
761
- self.min_vocab_count = self.dstats.min_vocab_count
762
- self.subgroup_files = {}
763
- self.npmi_terms_fid = pjoin(self.dstats.cache_path, "npmi_terms.json")
764
-
765
- def load_or_prepare_npmi_terms(self):
766
- """
767
- Figures out what identity terms the user can select, based on whether
768
- they occur at least self.min_vocab_count times
769
- :return: Identity terms occurring at least self.min_vocab_count times.
770
- """
771
- # TODO: Add the user's ability to select subgroups.
772
- # TODO: Make the min_vocab_count value here selectable by the user.
773
- if (
774
- self.use_cache
775
- and exists(self.npmi_terms_fid)
776
- and json.load(open(self.npmi_terms_fid))["available terms"] != []
777
- ):
778
- self.available_terms = json.load(open(self.npmi_terms_fid))["available terms"]
779
- else:
780
- if not self.live:
781
- if self.dstats.vocab_counts_df is None:
782
- self.dstats.load_or_prepare_vocab()
783
-
784
- true_false = [
785
- term in self.dstats.vocab_counts_df.index for term in self.termlist
786
- ]
787
- word_list_tmp = [x for x, y in zip(self.termlist, true_false) if y]
788
- true_false_counts = [
789
- self.dstats.vocab_counts_df.loc[word, CNT] >= self.min_vocab_count
790
- for word in word_list_tmp
791
- ]
792
- available_terms = [
793
- word for word, y in zip(word_list_tmp, true_false_counts) if y
794
- ]
795
- logs.info(available_terms)
796
- with open(self.npmi_terms_fid, "w+") as f:
797
- json.dump({"available terms": available_terms}, f)
798
- self.available_terms = available_terms
799
- return self.available_terms
800
-
801
- def load_or_prepare_joint_npmi(self, subgroup_pair, save=True):
802
- """
803
- Run on the fly, while the app is already open,
804
- as it depends on the subgroup terms that the user chooses
805
- :param subgroup_pair:
806
- :return:
807
- """
808
- # Canonical ordering for subgroup_list
809
- subgroup_pair = sorted(subgroup_pair)
810
- subgroup1 = subgroup_pair[0]
811
- subgroup2 = subgroup_pair[1]
812
- subgroups_str = "-".join(subgroup_pair)
813
- if not isdir(self.pmi_cache_path):
814
- logs.warning("Creating cache")
815
- # We need to preprocess everything.
816
- # This should eventually all go into a prepare_dataset CLI
817
- mkdir(self.pmi_cache_path)
818
- joint_npmi_fid = pjoin(self.pmi_cache_path, subgroups_str + "_npmi.csv")
819
- subgroup_files = define_subgroup_files(subgroup_pair, self.pmi_cache_path)
820
- # Defines the filenames for the cache files from the selected subgroups.
821
- # Get as much precomputed data as we can.
822
- if self.use_cache and exists(joint_npmi_fid):
823
- # When everything is already computed for the selected subgroups.
824
- logs.info("Loading cached joint npmi")
825
- joint_npmi_df = self.load_joint_npmi_df(joint_npmi_fid)
826
- npmi_display_cols = ['npmi-bias', subgroup1 + '-npmi', subgroup2 + '-npmi', subgroup1 + '-count', subgroup2 + '-count']
827
- joint_npmi_df = joint_npmi_df[npmi_display_cols]
828
- # When maybe some things have been computed for the selected subgroups.
829
- else:
830
- if not self.live:
831
- logs.info("Preparing new joint npmi")
832
- joint_npmi_df, subgroup_dict = self.prepare_joint_npmi_df(
833
- subgroup_pair, subgroup_files
834
- )
835
- if save:
836
- if joint_npmi_df is not None:
837
- # Cache new results
838
- logs.info("Writing out.")
839
- for subgroup in subgroup_pair:
840
- write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files)
841
- with open(joint_npmi_fid, "w+") as f:
842
- joint_npmi_df.to_csv(f)
843
- else:
844
- joint_npmi_df = pd.DataFrame()
845
- logs.info("The joint npmi df is")
846
- logs.info(joint_npmi_df)
847
- return joint_npmi_df
848
-
849
- def load_joint_npmi_df(self, joint_npmi_fid):
850
- """
851
- Reads in a saved dataframe with all of the paired results.
852
- :param joint_npmi_fid:
853
- :return: paired results
854
- """
855
- with open(joint_npmi_fid, "rb") as f:
856
- joint_npmi_df = pd.read_csv(f)
857
- joint_npmi_df = self._set_idx_cols_from_cache(joint_npmi_df)
858
- return joint_npmi_df.dropna()
859
-
860
- def prepare_joint_npmi_df(self, subgroup_pair, subgroup_files):
861
- """
862
- Computes the npmi bias based on the given subgroups.
863
- Handles cases where some of the selected subgroups have cached nPMI
864
- computations but others don't, computing everything afresh if there
865
- are no cached files.
866
- :param subgroup_pair:
867
- :return: Dataframe with nPMI for the words, nPMI bias between the words.
868
- """
869
- subgroup_dict = {}
870
- # When npmi is computed for some (but not all) of subgroup_list
871
- for subgroup in subgroup_pair:
872
- logs.info("Load or failing...")
873
- # When subgroup npmi has been computed in a prior session.
874
- cached_results = self.load_or_fail_cached_npmi_scores(
875
- subgroup, subgroup_files[subgroup]
876
- )
877
- # If the function did not return False (i.e., cached results were found), use them.
878
- if cached_results:
879
- # FYI: subgroup_cooc_df, subgroup_pmi_df, subgroup_npmi_df = cached_results
880
- # Holds the previous sessions' data for use in this session.
881
- subgroup_dict[subgroup] = cached_results
882
- logs.info("Calculating for subgroup list")
883
- joint_npmi_df, subgroup_dict = self.do_npmi(subgroup_pair, subgroup_dict)
884
- return joint_npmi_df, subgroup_dict
885
-
886
- # TODO: Update pairwise assumption
887
- def do_npmi(self, subgroup_pair, subgroup_dict):
888
- """
889
- Calculates nPMI for given identity terms and the nPMI bias between.
890
- :param subgroup_pair: List of identity terms to calculate the bias for
891
- :return: Subset of data for the UI
892
- :return: Selected identity term's co-occurrence counts with
893
- other words, pmi per word, and nPMI per word.
894
- """
895
- no_results = False
896
- logs.info("Initializing npmi class")
897
- npmi_obj = self.set_npmi_obj()
898
- # Canonical ordering used
899
- subgroup_pair = tuple(sorted(subgroup_pair))
900
- # Calculating nPMI statistics
901
- for subgroup in subgroup_pair:
902
- # If the subgroup data is already computed, grab it.
903
- # TODO: Should we set idx and column names similarly to
904
- # how we set them for cached files?
905
- if subgroup not in subgroup_dict:
906
- logs.info("Calculating statistics for %s" % subgroup)
907
- vocab_cooc_df, pmi_df, npmi_df = npmi_obj.calc_metrics(subgroup)
908
- if vocab_cooc_df is None:
909
- no_results = True
910
- else:
911
- # Store the nPMI information for the current subgroups
912
- subgroup_dict[subgroup] = (vocab_cooc_df, pmi_df, npmi_df)
913
- if no_results:
914
- logs.warning("Couldn't grap the npmi files -- Under construction")
915
- return None, None
916
- else:
917
- # Pair the subgroups together, indexed by all words that
918
- # co-occur between them.
919
- logs.info("Computing pairwise npmi bias")
920
- paired_results = npmi_obj.calc_paired_metrics(subgroup_pair, subgroup_dict)
921
- UI_results = make_npmi_fig(paired_results, subgroup_pair)
922
- return UI_results.dropna(), subgroup_dict
923
-
924
- def set_npmi_obj(self):
925
- """
926
- Initializes the nPMI class with the given words and tokenized sentences.
927
- :return:
928
- """
929
- npmi_obj = nPMI(self.dstats.vocab_counts_df, self.dstats.tokenized_df)
930
- return npmi_obj
931
-
932
- def load_or_fail_cached_npmi_scores(self, subgroup, subgroup_fids):
933
- """
934
- Reads cached scores from the specified subgroup files
935
- :param subgroup: string of the selected identity term
936
- :return:
937
- """
938
- # TODO: Ordering of npmi, pmi, vocab triple should be consistent
939
- subgroup_npmi_fid, subgroup_pmi_fid, subgroup_cooc_fid = subgroup_fids
940
- if (
941
- exists(subgroup_npmi_fid)
942
- and exists(subgroup_pmi_fid)
943
- and exists(subgroup_cooc_fid)
944
- ):
945
- logs.info("Reading in pmi data....")
946
- with open(subgroup_cooc_fid, "rb") as f:
947
- subgroup_cooc_df = pd.read_csv(f)
948
- logs.info("pmi")
949
- with open(subgroup_pmi_fid, "rb") as f:
950
- subgroup_pmi_df = pd.read_csv(f)
951
- logs.info("npmi")
952
- with open(subgroup_npmi_fid, "rb") as f:
953
- subgroup_npmi_df = pd.read_csv(f)
954
- subgroup_cooc_df = self._set_idx_cols_from_cache(
955
- subgroup_cooc_df, subgroup, "count"
956
- )
957
- subgroup_pmi_df = self._set_idx_cols_from_cache(
958
- subgroup_pmi_df, subgroup, "pmi"
959
- )
960
- subgroup_npmi_df = self._set_idx_cols_from_cache(
961
- subgroup_npmi_df, subgroup, "npmi"
962
- )
963
- return subgroup_cooc_df, subgroup_pmi_df, subgroup_npmi_df
964
- return False
965
-
966
- def _set_idx_cols_from_cache(self, csv_df, subgroup=None, calc_str=None):
967
- """
968
- Helps make sure all of the read-in files can be accessed within code
969
- via standardized indices and column names.
970
- :param csv_df:
971
- :param subgroup:
972
- :param calc_str:
973
- :return:
974
- """
975
- # The csv saves with this column instead of the index, so that's weird.
976
- if "Unnamed: 0" in csv_df.columns:
977
- csv_df = csv_df.set_index("Unnamed: 0")
978
- csv_df.index.name = WORD
979
- elif WORD in csv_df.columns:
980
- csv_df = csv_df.set_index(WORD)
981
- csv_df.index.name = WORD
982
- elif VOCAB in csv_df.columns:
983
- csv_df = csv_df.set_index(VOCAB)
984
- csv_df.index.name = WORD
985
- if subgroup and calc_str:
986
- csv_df.columns = [subgroup + "-" + calc_str]
987
- elif subgroup:
988
- csv_df.columns = [subgroup]
989
- elif calc_str:
990
- csv_df.columns = [calc_str]
991
- return csv_df
992
-
993
- def get_available_terms(self):
994
- return self.load_or_prepare_npmi_terms()
995
-
996
- def dummy(doc):
997
- return doc
998
-
999
- def count_vocab_frequencies(tokenized_df):
1000
- """
1001
- Based on an input pandas DataFrame with a column of tokenized text,
1002
- this function will count the occurrences of all words.
1003
- :return: DataFrame with one row per vocabulary word and a single
1004
- count column giving each word's total number of occurrences.
1005
- """
1006
-
1007
- cvec = CountVectorizer(
1008
- tokenizer=dummy,
1009
- preprocessor=dummy,
1010
- )
1011
- # We do this to calculate per-word statistics
1012
- # Fast calculation of single word counts
1013
- logs.info("Fitting dummy tokenization to make matrix using the previous tokenization")
1014
- cvec.fit(tokenized_df[TOKENIZED_FIELD])
1015
- document_matrix = cvec.transform(tokenized_df[TOKENIZED_FIELD])
1016
- batches = np.linspace(0, tokenized_df.shape[0], _NUM_VOCAB_BATCHES).astype(int)
1017
- i = 0
1018
- tf = []
1019
- while i < len(batches) - 1:
1020
- logs.info("%s of %s vocab batches" % (str(i), str(len(batches))))
1021
- batch_result = np.sum(
1022
- document_matrix[batches[i] : batches[i + 1]].toarray(), axis=0
1023
- )
1024
- tf.append(batch_result)
1025
- i += 1
1026
- word_count_df = pd.DataFrame(
1027
- [np.sum(tf, axis=0)], columns=cvec.get_feature_names()
1028
- ).transpose()
1029
- # Now organize everything into the dataframes
1030
- word_count_df.columns = [CNT]
1031
- word_count_df.index.name = WORD
1032
- return word_count_df
1033
-
1034
- def calc_p_word(word_count_df):
1035
- # p(word)
1036
- word_count_df[PROP] = word_count_df[CNT] / float(sum(word_count_df[CNT]))
1037
- vocab_counts_df = pd.DataFrame(word_count_df.sort_values(by=CNT, ascending=False))
1038
- vocab_counts_df[VOCAB] = vocab_counts_df.index
1039
- return vocab_counts_df
1040
-
1041
-
1042
- def filter_vocab(vocab_counts_df):
1043
- # TODO: Add warnings (which words are missing) to log file?
1044
- filtered_vocab_counts_df = vocab_counts_df.drop(_CLOSED_CLASS,
1045
- errors="ignore")
1046
- filtered_count = filtered_vocab_counts_df[CNT]
1047
- filtered_count_denom = float(sum(filtered_vocab_counts_df[CNT]))
1048
- filtered_vocab_counts_df[PROP] = filtered_count / filtered_count_denom
1049
- return filtered_vocab_counts_df
1050
-
1051
-
1052
- ## Figures ##
1053
-
1054
- def write_plotly(fig, fid):
1055
- write_json(plotly.io.to_json(fig), fid)
1056
-
1057
- def read_plotly(fid):
1058
- fig = plotly.io.from_json(json.load(open(fid, encoding="utf-8")))
1059
- return fig
1060
-
1061
- def make_fig_lengths(tokenized_df, length_field):
1062
- fig_tok_length = px.histogram(
1063
- tokenized_df, x=length_field, marginal="rug", hover_data=[length_field]
1064
- )
1065
- return fig_tok_length
1066
-
1067
- def make_fig_labels(label_df, label_names, label_field):
1068
- labels = label_df[label_field].unique()
1069
- label_sums = [len(label_df[label_df[label_field] == label]) for label in labels]
1070
- fig_labels = px.pie(label_df, values=label_sums, names=label_names)
1071
- return fig_labels
1072
-
1073
-
1074
- def make_zipf_fig_ranked_word_list(vocab_df, unique_counts, unique_ranks):
1075
- ranked_words = {}
1076
- for count, rank in zip(unique_counts, unique_ranks):
1077
- vocab_df.loc[vocab_df[CNT] == count, "rank"] = rank
1078
- ranked_words[rank] = ",".join(
1079
- vocab_df[vocab_df[CNT] == count].index.astype(str)
1080
- ) # Use the hovertext kw argument for hover text
1081
- ranked_words_list = [wrds for rank, wrds in sorted(ranked_words.items())]
1082
- return ranked_words_list
1083
-
1084
-
1085
- def make_npmi_fig(paired_results, subgroup_pair):
1086
- subgroup1, subgroup2 = subgroup_pair
1087
- UI_results = pd.DataFrame()
1088
- if "npmi-bias" in paired_results:
1089
- UI_results["npmi-bias"] = paired_results["npmi-bias"].astype(float)
1090
- UI_results[subgroup1 + "-npmi"] = paired_results["npmi"][
1091
- subgroup1 + "-npmi"
1092
- ].astype(float)
1093
- UI_results[subgroup1 + "-count"] = paired_results["count"][
1094
- subgroup1 + "-count"
1095
- ].astype(int)
1096
- if subgroup1 != subgroup2:
1097
- UI_results[subgroup2 + "-npmi"] = paired_results["npmi"][
1098
- subgroup2 + "-npmi"
1099
- ].astype(float)
1100
- UI_results[subgroup2 + "-count"] = paired_results["count"][
1101
- subgroup2 + "-count"
1102
- ].astype(int)
1103
- return UI_results.sort_values(by="npmi-bias", ascending=True)
1104
-
1105
-
1106
- def make_zipf_fig(vocab_counts_df, z):
1107
- zipf_counts = z.calc_zipf_counts(vocab_counts_df)
1108
- unique_counts = z.uniq_counts
1109
- unique_ranks = z.uniq_ranks
1110
- ranked_words_list = make_zipf_fig_ranked_word_list(
1111
- vocab_counts_df, unique_counts, unique_ranks
1112
- )
1113
- zmin = z.get_xmin()
1114
- logs.info("zipf counts is")
1115
- logs.info(zipf_counts)
1116
- layout = go.Layout(xaxis=dict(range=[0, 100]))
1117
- fig = go.Figure(
1118
- data=[
1119
- go.Bar(
1120
- x=z.uniq_ranks,
1121
- y=z.uniq_counts,
1122
- hovertext=ranked_words_list,
1123
- name="Word Rank Frequency",
1124
- )
1125
- ],
1126
- layout=layout,
1127
- )
1128
- fig.add_trace(
1129
- go.Scatter(
1130
- x=z.uniq_ranks[zmin : len(z.uniq_ranks)],
1131
- y=zipf_counts[zmin : len(z.uniq_ranks)],
1132
- hovertext=ranked_words_list[zmin : len(z.uniq_ranks)],
1133
- line=go.scatter.Line(color="crimson", width=3),
1134
- name="Zipf Predicted Frequency",
1135
- )
1136
- )
1137
- # Customize aspect
1138
- # fig.update_traces(marker_color='limegreen',
1139
- # marker_line_width=1.5, opacity=0.6)
1140
- fig.update_layout(title_text="Word Counts, Observed and Predicted by Zipf")
1141
- fig.update_layout(xaxis_title="Word Rank")
1142
- fig.update_layout(yaxis_title="Frequency")
1143
- fig.update_layout(legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.10))
1144
- return fig
1145
-
1146
-
1147
- def make_tree_plot(node_list, text_dset):
1148
- nid_map = dict([(node["nid"], nid) for nid, node in enumerate(node_list)])
1149
-
1150
- for nid, node in enumerate(node_list):
1151
- node["label"] = node.get(
1152
- "label",
1153
- f"{nid:2d} - {node['weight']:5d} items <br>"
1154
- + "<br>".join(
1155
- [
1156
- "> " + txt[:64] + ("..." if len(txt) >= 63 else "")
1157
- for txt in list(
1158
- set(text_dset.select(node["example_ids"])[OUR_TEXT_FIELD])
1159
- )[:5]
1160
- ]
1161
- ),
1162
- )
1163
-
1164
- # make plot nodes
1165
- # TODO: something more efficient than set to remove duplicates
1166
- labels = [node["label"] for node in node_list]
1167
-
1168
- root = node_list[0]
1169
- root["X"] = 0
1170
- root["Y"] = 0
1171
-
1172
- def rec_make_coordinates(node):
1173
- total_weight = 0
1174
- add_weight = len(node["example_ids"]) - sum(
1175
- [child["weight"] for child in node["children"]]
1176
- )
1177
- for child in node["children"]:
1178
- child["X"] = node["X"] + total_weight
1179
- child["Y"] = node["Y"] - 1
1180
- total_weight += child["weight"] + add_weight / len(node["children"])
1181
- rec_make_coordinates(child)
1182
-
1183
- rec_make_coordinates(root)
1184
-
1185
- E = [] # list of edges
1186
- Xn = []
1187
- Yn = []
1188
- Xe = []
1189
- Ye = []
1190
- for nid, node in enumerate(node_list):
1191
- Xn += [node["X"]]
1192
- Yn += [node["Y"]]
1193
- for child in node["children"]:
1194
- E += [(nid, nid_map[child["nid"]])]
1195
- Xe += [node["X"], child["X"], None]
1196
- Ye += [node["Y"], child["Y"], None]
1197
-
1198
- # make figure
1199
- fig = go.Figure()
1200
- fig.add_trace(
1201
- go.Scatter(
1202
- x=Xe,
1203
- y=Ye,
1204
- mode="lines",
1205
- line=dict(color="rgb(210,210,210)", width=1),
1206
- hoverinfo="none",
1207
- )
1208
- )
1209
- fig.add_trace(
1210
- go.Scatter(
1211
- x=Xn,
1212
- y=Yn,
1213
- mode="markers",
1214
- name="nodes",
1215
- marker=dict(
1216
- symbol="circle-dot",
1217
- size=18,
1218
- color="#6175c1",
1219
- line=dict(color="rgb(50,50,50)", width=1)
1220
- # '#DB4551',
1221
- ),
1222
- text=labels,
1223
- hoverinfo="text",
1224
- opacity=0.8,
1225
- )
1226
- )
1227
- return fig
1228
-
1229
-
1230
- ## Input/Output ###
1231
-
1232
-
1233
- def define_subgroup_files(subgroup_list, pmi_cache_path):
1234
- """
1235
- Sets the file ids for the input identity terms
1236
- :param subgroup_list: List of identity terms
1237
- :return:
1238
- """
1239
- subgroup_files = {}
1240
- for subgroup in subgroup_list:
1241
- # TODO: Should the pmi, npmi, and count just be one file?
1242
- subgroup_npmi_fid = pjoin(pmi_cache_path, subgroup + "_npmi.csv")
1243
- subgroup_pmi_fid = pjoin(pmi_cache_path, subgroup + "_pmi.csv")
1244
- subgroup_cooc_fid = pjoin(pmi_cache_path, subgroup + "_vocab_cooc.csv")
1245
- subgroup_files[subgroup] = (
1246
- subgroup_npmi_fid,
1247
- subgroup_pmi_fid,
1248
- subgroup_cooc_fid,
1249
- )
1250
- return subgroup_files
1251
-
1252
-
1253
- ## Input/Output ##
1254
-
1255
-
1256
- def intersect_dfs(df_dict):
1257
- started = 0
1258
- new_df = None
1259
- for key, df in df_dict.items():
1260
- if df is None:
1261
- continue
1262
- for key2, df2 in df_dict.items():
1263
- if df2 is None:
1264
- continue
1265
- if key == key2:
1266
- continue
1267
- if started:
1268
- new_df = new_df.join(df2, how="inner", lsuffix="1", rsuffix="2")
1269
- else:
1270
- new_df = df.join(df2, how="inner", lsuffix="1", rsuffix="2")
1271
- started = 1
1272
- return new_df.copy()
1273
-
1274
-
1275
- def write_df(df, df_fid):
1276
- feather.write_feather(df, df_fid)
1277
-
1278
-
1279
- def write_json(json_dict, json_fid):
1280
- with open(json_fid, "w", encoding="utf-8") as f:
1281
- json.dump(json_dict, f)
1282
-
1283
- def write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files):
1284
- """
1285
- Saves the calculated nPMI statistics to their output files.
1286
- Includes the npmi scores for each identity term, the pmi scores, and the
1287
- co-occurrence counts of the identity term with all the other words
1288
- :param subgroup: Identity term
1289
- :return:
1290
- """
1291
- subgroup_fids = subgroup_files[subgroup]
1292
- subgroup_npmi_fid, subgroup_pmi_fid, subgroup_cooc_fid = subgroup_fids
1293
- subgroup_dfs = subgroup_dict[subgroup]
1294
- subgroup_cooc_df, subgroup_pmi_df, subgroup_npmi_df = subgroup_dfs
1295
- with open(subgroup_npmi_fid, "w+") as f:
1296
- subgroup_npmi_df.to_csv(f)
1297
- with open(subgroup_pmi_fid, "w+") as f:
1298
- subgroup_pmi_df.to_csv(f)
1299
- with open(subgroup_cooc_fid, "w+") as f:
1300
- subgroup_cooc_df.to_csv(f)
1301
-
1302
- def write_zipf_data(z, zipf_fid):
1303
- zipf_dict = {}
1304
- zipf_dict["xmin"] = int(z.xmin)
1305
- zipf_dict["xmax"] = int(z.xmax)
1306
- zipf_dict["alpha"] = float(z.alpha)
1307
- zipf_dict["ks_distance"] = float(z.distance)
1308
- zipf_dict["p-value"] = float(z.ks_test.pvalue)
1309
- zipf_dict["uniq_counts"] = [int(count) for count in z.uniq_counts]
1310
- zipf_dict["uniq_ranks"] = [int(rank) for rank in z.uniq_ranks]
1311
- with open(zipf_fid, "w+", encoding="utf-8") as f:
1312
- json.dump(zipf_dict, f)
1313
-