meg-huggingface committed
Commit 4b53042
Parent: 66693d5

Begins modularizing so that each widget can be loaded independently, without depending on the order of the load_or_prepare calls in app.py. Each function corresponding to a widget now checks whether the variables it depends on have been calculated yet; if not, it calls back to calculate them. Because this made passing the use_cache variable around messy, use_cache is now stored on the DatasetStatisticsCacheClass instance, set once when it is initialized, and the use_cache arguments that appeared in nearly every function have been removed.
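
In outline, the check-and-compute-on-demand pattern looks like this (a minimal, hypothetical sketch; WidgetStats stands in for the real DatasetStatisticsCacheClass, and the method bodies are placeholders, not the real computations):

# Hypothetical stand-in for DatasetStatisticsCacheClass, pattern only.
class WidgetStats:
    def __init__(self, use_cache=False):
        # use_cache is set once here instead of being threaded through
        # every load_or_prepare_* call.
        self.use_cache = use_cache
        self.tokenized_df = None
        self.dup_counts_df = None

    def load_or_prepare_tokenized_df(self):
        if self.tokenized_df is None:
            self.tokenized_df = ["tokenized", "rows"]  # placeholder work

    def load_or_prepare_text_duplicates(self):
        # The widget loader checks its own dependency and computes it on
        # demand, so callers can invoke the loaders in any order.
        if self.tokenized_df is None:
            self.load_or_prepare_tokenized_df()
        if self.dup_counts_df is None:
            self.dup_counts_df = {"duplicates": 0}  # placeholder work

stats = WidgetStats(use_cache=True)
stats.load_or_prepare_text_duplicates()  # works without prior setup calls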

app.py CHANGED
@@ -100,30 +100,63 @@ def load_or_prepare(ds_args, show_embeddings, use_cache=False):
         mkdir(CACHE_DIR)
     if use_cache:
         logs.warning("Using cache")
-    dstats = dataset_statistics.DatasetStatisticsCacheClass(CACHE_DIR, **ds_args)
+    dstats = dataset_statistics.DatasetStatisticsCacheClass(CACHE_DIR, **ds_args, use_cache=use_cache)
     logs.warning("Loading Dataset")
-    dstats.load_or_prepare_dataset(use_cache=use_cache)
+    dstats.load_or_prepare_dataset()
     logs.warning("Extracting Labels")
-    dstats.load_or_prepare_labels(use_cache=use_cache)
+    dstats.load_or_prepare_labels()
     logs.warning("Computing Text Lengths")
-    dstats.load_or_prepare_text_lengths(use_cache=use_cache)
+    dstats.load_or_prepare_text_lengths()
+    logs.warning("Computing Duplicates")
+    dstats.load_or_prepare_text_duplicates()
     logs.warning("Extracting Vocabulary")
-    dstats.load_or_prepare_vocab(use_cache=use_cache)
+    dstats.load_or_prepare_vocab()
     logs.warning("Calculating General Statistics...")
-    dstats.load_or_prepare_general_stats(use_cache=use_cache)
+    dstats.load_or_prepare_general_stats()
     logs.warning("Completed Calculation.")
     logs.warning("Calculating Fine-Grained Statistics...")
     if show_embeddings:
         logs.warning("Loading Embeddings")
-        dstats.load_or_prepare_embeddings(use_cache=use_cache)
+        dstats.load_or_prepare_embeddings()
     print(dstats.fig_tree)
     # TODO: This has now been moved to calculation when the npmi widget is loaded.
     logs.warning("Loading Terms for nPMI")
-    dstats.load_or_prepare_npmi_terms(use_cache=use_cache)
+    dstats.load_or_prepare_npmi_terms()
     logs.warning("Loading Zipf")
-    dstats.load_or_prepare_zipf(use_cache=use_cache)
+    dstats.load_or_prepare_zipf()
     return dstats
 
+def load_or_prepare_widgets(ds_args, show_embeddings, use_cache=False):
+    """
+    Loader specifically for the widgets used in the app.
+    Args:
+        ds_args:
+        show_embeddings:
+        use_cache:
+
+    Returns:
+
+    """
+    if not isdir(CACHE_DIR):
+        logs.warning("Creating cache")
+        # We need to preprocess everything.
+        # This should eventually all go into a prepare_dataset CLI
+        mkdir(CACHE_DIR)
+    if use_cache:
+        logs.warning("Using cache")
+    dstats = dataset_statistics.DatasetStatisticsCacheClass(CACHE_DIR, **ds_args, use_cache=use_cache)
+    # Header widget
+    dstats.load_or_prepare_dset_peek()
+    # General stats widget
+    dstats.load_or_prepare_general_stats()
+    # Labels widget
+    dstats.load_or_prepare_labels()
+    # Text lengths widget
+    dstats.load_or_prepare_text_lengths()
+    if show_embeddings:
+        # Embeddings widget
+        dstats.load_or_prepare_embeddings()
+    dstats.load_or_prepare_text_duplicates()
 
 def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=True):
     """
@@ -144,7 +177,7 @@ def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=T
     st_utils.expander_header(dstats, ds_name_to_dict, column_id)
     logs.info("showing general stats")
     st_utils.expander_general_stats(dstats, column_id)
-    st_utils.expander_label_distribution(dstats.label_df, dstats.fig_labels, column_id)
+    st_utils.expander_label_distribution(dstats.fig_labels, column_id)
     st_utils.expander_text_lengths(
         dstats.tokenized_df,
         dstats.fig_tok_length,
@@ -163,7 +196,7 @@ def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=T
     npmi_stats = dataset_statistics.nPMIStatisticsCacheClass(
         dstats, use_cache=use_cache
     )
-    available_terms = npmi_stats.get_available_terms(use_cache=use_cache)
+    available_terms = npmi_stats.get_available_terms()
     st_utils.npmi_widget(
         column_id, available_terms, npmi_stats, _MIN_VOCAB_COUNT, use_cache=use_cache
     )
@@ -190,7 +223,7 @@ def main():
     compare_mode = st.sidebar.checkbox("Comparison mode")
 
     # When not doing new development, use the cache.
-    use_cache = True
+    use_cache = False
     show_embeddings = st.sidebar.checkbox("Show embeddings")
     # List of datasets for which embeddings are hard to compute:
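
For context, the new load_or_prepare_widgets entry point would be called roughly as below. This is a sketch only: the ds_args keys are assumptions pieced together from the constructor parameters visible in this commit (dset_name, label_field, label_names), not a confirmed signature, and the real dict likely carries more entries.

# Hypothetical call site; key names are assumptions based on this commit.
from app import load_or_prepare_widgets

ds_args = {
    "dset_name": "imdb",
    "label_field": "label",
    "label_names": ["neg", "pos"],
}
# Precompute (or load from cache) everything each widget needs up front.
load_or_prepare_widgets(ds_args, show_embeddings=False, use_cache=True)
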
data_measurements/dataset_statistics.py CHANGED
@@ -159,6 +159,7 @@ class DatasetStatisticsCacheClass:
         label_field,
         label_names,
         calculation=None,
+        use_cache=False,
     ):
         # This is only used for standalone runs for each kind of measurement.
         self.calculation = calculation
@@ -168,6 +169,8 @@ class DatasetStatisticsCacheClass:
         self.our_tokenized_field = TOKENIZED_FIELD
         self.our_embedding_field = EMBEDDING_FIELD
         self.cache_dir = cache_dir
+        # Use stored data if there; otherwise calculate afresh
+        self.use_cache = use_cache
         ### What are we analyzing?
         # name of the Hugging Face dataset
         self.dset_name = dset_name
@@ -285,20 +288,19 @@ class DatasetStatisticsCacheClass:
             use_streaming=True,
         )
 
-    def load_or_prepare_general_stats(self, use_cache=False, save=True):
+    def load_or_prepare_general_stats(self, save=True):
         """
         Content for expander_general_stats widget.
         Provides statistics for total words, total open words,
         the sorted top vocab, the NaN count, and the duplicate count.
         Args:
-            use_cache:
 
         Returns:
 
         """
         # General statistics
         if (
-            use_cache
+            self.use_cache
             and exists(self.general_stats_fid)
             and exists(self.dup_counts_df_fid)
             and exists(self.sorted_top_vocab_df_fid)
@@ -320,10 +322,10 @@ class DatasetStatisticsCacheClass:
             write_json(self.general_stats_dict, self.general_stats_fid)
 
 
-    def load_or_prepare_text_lengths(self, use_cache=False, save=True):
+    def load_or_prepare_text_lengths(self, save=True):
         # TODO: Everything here can be read from cache; it's in a transitory
         # state atm where just the fig is cached. Clean up.
-        if use_cache and exists(self.fig_tok_length_fid):
+        if self.use_cache and exists(self.fig_tok_length_fid):
             self.fig_tok_length = read_plotly(self.fig_tok_length_fid)
         if self.tokenized_df is None:
             self.tokenized_df = self.do_tokenization()
@@ -340,18 +342,18 @@ class DatasetStatisticsCacheClass:
         if save:
             write_plotly(self.fig_tok_length, self.fig_tok_length_fid)
 
-    def load_or_prepare_embeddings(self, use_cache=False, save=True):
-        if use_cache and exists(self.node_list_fid) and exists(self.fig_tree_fid):
+    def load_or_prepare_embeddings(self, save=True):
+        if self.use_cache and exists(self.node_list_fid) and exists(self.fig_tree_fid):
             self.node_list = torch.load(self.node_list_fid)
             self.fig_tree = read_plotly(self.fig_tree_fid)
-        elif use_cache and exists(self.node_list_fid):
+        elif self.use_cache and exists(self.node_list_fid):
             self.node_list = torch.load(self.node_list_fid)
             self.fig_tree = make_tree_plot(self.node_list,
                                            self.text_dset)
             if save:
                 write_plotly(self.fig_tree, self.fig_tree_fid)
         else:
-            self.embeddings = Embeddings(self, use_cache=use_cache)
+            self.embeddings = Embeddings(self, use_cache=self.use_cache)
             self.embeddings.make_hierarchical_clustering()
             self.node_list = self.embeddings.node_list
             self.fig_tree = make_tree_plot(self.node_list,
@@ -361,15 +363,15 @@ class DatasetStatisticsCacheClass:
                 write_plotly(self.fig_tree, self.fig_tree_fid)
 
     # get vocab with word counts
-    def load_or_prepare_vocab(self, use_cache=True, save=True):
+    def load_or_prepare_vocab(self, save=True):
         """
         Calculates the vocabulary count from the tokenized text.
         The resulting dataframes may be used in nPMI calculations, zipf, etc.
-        :param use_cache:
+        :param
         :return:
         """
         if (
-            use_cache
+            self.use_cache
             and exists(self.vocab_counts_df_fid)
         ):
             logs.info("Reading vocab from cache")
@@ -400,10 +402,23 @@ class DatasetStatisticsCacheClass:
         # Handling for changes in how the index is saved.
         self.vocab_counts_df = self._set_idx_col_names(self.vocab_counts_df)
 
+    def load_or_prepare_text_duplicates(self, save=True):
+        if self.use_cache and exists(self.dup_counts_df_fid):
+            with open(self.dup_counts_df_fid, "rb") as f:
+                self.dup_counts_df = feather.read_feather(f)
+        elif self.dup_counts_df is None:
+            self.prepare_text_duplicates()
+            if save:
+                write_df(self.dup_counts_df, self.dup_counts_df_fid)
+        else:
+            # This happens when self.dup_counts_df is already defined;
+            # This happens when general_statistics were calculated first,
+            # since general statistics requires the number of duplicates
+            if save:
+                write_df(self.dup_counts_df, self.dup_counts_df_fid)
+
     def load_general_stats(self):
         self.general_stats_dict = json.load(open(self.general_stats_fid, encoding="utf-8"))
-        with open(self.dup_counts_df_fid, "rb") as f:
-            self.dup_counts_df = feather.read_feather(f)
         with open(self.sorted_top_vocab_df_fid, "rb") as f:
             self.sorted_top_vocab_df = feather.read_feather(f)
         self.text_nan_count = self.general_stats_dict[TEXT_NAN_CNT]
@@ -421,20 +436,10 @@ class DatasetStatisticsCacheClass:
         self.sorted_top_vocab_df = self.vocab_counts_filtered_df.sort_values(
             "count", ascending=False
         ).head(_TOP_N)
-        print('basics')
         self.total_words = len(self.vocab_counts_df)
         self.total_open_words = len(self.vocab_counts_filtered_df)
         self.text_nan_count = int(self.tokenized_df.isnull().sum().sum())
-        dup_df = self.tokenized_df[self.tokenized_df.duplicated([OUR_TEXT_FIELD])]
-        print('dup df')
-        self.dup_counts_df = pd.DataFrame(
-            dup_df.pivot_table(
-                columns=[OUR_TEXT_FIELD], aggfunc="size"
-            ).sort_values(ascending=False),
-            columns=[CNT],
-        )
-        print('deddup df')
-        self.dup_counts_df[OUR_TEXT_FIELD] = self.dup_counts_df.index.copy()
+        self.prepare_text_duplicates()
         self.dedup_total = sum(self.dup_counts_df[CNT])
         self.general_stats_dict = {
             TOT_WORDS: self.total_words,
@@ -443,28 +448,40 @@ class DatasetStatisticsCacheClass:
             DEDUP_TOT: self.dedup_total,
         }
 
-    def load_or_prepare_dataset(self, use_cache=True, save=True):
+    def prepare_text_duplicates(self):
+        if self.tokenized_df is None:
+            self.load_or_prepare_tokenized_df()
+        dup_df = self.tokenized_df[
+            self.tokenized_df.duplicated([OUR_TEXT_FIELD])]
+        self.dup_counts_df = pd.DataFrame(
+            dup_df.pivot_table(
+                columns=[OUR_TEXT_FIELD], aggfunc="size"
+            ).sort_values(ascending=False),
+            columns=[CNT],
+        )
+        self.dup_counts_df[OUR_TEXT_FIELD] = self.dup_counts_df.index.copy()
+
+    def load_or_prepare_dataset(self, save=True):
         """
         Prepares the HF datasets and data frames containing the untokenized and
         tokenized text as well as the label values.
         self.tokenized_df is used further for calculating text lengths,
         word counts, etc.
         Args:
-            use_cache: Used stored data if there; otherwise calculate afresh
             save: Store the calculated data to disk.
 
         Returns:
 
         """
         logs.info("Doing text dset.")
-        self.load_or_prepare_text_dset(use_cache, save)
+        self.load_or_prepare_text_dset(save)
         logs.info("Doing tokenized dataframe")
-        self.load_or_prepare_tokenized_df(use_cache, save)
+        self.load_or_prepare_tokenized_df(save)
         logs.info("Doing dataset peek")
-        self.load_or_prepare_dset_peek(save, use_cache)
+        self.load_or_prepare_dset_peek(save)
 
-    def load_or_prepare_dset_peek(self, save, use_cache):
-        if use_cache and exists(self.dset_peek_fid):
+    def load_or_prepare_dset_peek(self, save=True):
+        if self.use_cache and exists(self.dset_peek_fid):
             with open(self.dset_peek_fid, "r") as f:
                 self.dset_peek = json.load(f)["dset peek"]
         else:
@@ -472,10 +489,10 @@ class DatasetStatisticsCacheClass:
             self.get_base_dataset()
             self.dset_peek = self.dset[:100]
             if save:
-                write_json({"dset_peek": self.dset_peek}, self.dset_peek_fid)
+                write_json({"dset peek": self.dset_peek}, self.dset_peek_fid)
 
-    def load_or_prepare_tokenized_df(self, use_cache, save):
-        if (use_cache and exists(self.tokenized_df_fid)):
+    def load_or_prepare_tokenized_df(self, save=True):
+        if (self.use_cache and exists(self.tokenized_df_fid)):
             self.tokenized_df = feather.read_feather(self.tokenized_df_fid)
         else:
             # tokenize all text instances
@@ -485,8 +502,8 @@ class DatasetStatisticsCacheClass:
             # save tokenized text
             write_df(self.tokenized_df, self.tokenized_df_fid)
 
-    def load_or_prepare_text_dset(self, use_cache, save):
-        if (use_cache and exists(self.text_dset_fid)):
+    def load_or_prepare_text_dset(self, save=True):
+        if (self.use_cache and exists(self.text_dset_fid)):
             # load extracted text
             self.text_dset = load_from_disk(self.text_dset_fid)
             logs.warning("Loaded dataset from disk")
@@ -515,6 +532,8 @@ class DatasetStatisticsCacheClass:
         Tokenizes the dataset
         :return:
         """
+        if self.text_dset is None:
+            self.load_or_prepare_text_dset()
         sent_tokenizer = self.cvec.build_tokenizer()
 
         def tokenize_batch(examples):
@@ -544,19 +563,18 @@ class DatasetStatisticsCacheClass:
         """
         self.label_field = label_field
 
-    def load_or_prepare_labels(self, use_cache=False, save=True):
+    def load_or_prepare_labels(self, save=True):
         # TODO: This is in a transitory state for creating fig cache.
         # Clean up to be caching and reading everything correctly.
         """
         Extracts labels from the Dataset
-        :param use_cache:
         :return:
         """
         # extracted labels
         if len(self.label_field) > 0:
-            if use_cache and exists(self.fig_labels_fid):
+            if self.use_cache and exists(self.fig_labels_fid):
                 self.fig_labels = read_plotly(self.fig_labels_fid)
-            elif use_cache and exists(self.label_dset_fid):
+            elif self.use_cache and exists(self.label_dset_fid):
                 # load extracted labels
                 self.label_dset = load_from_disk(self.label_dset_fid)
                 self.label_df = self.label_dset.to_pandas()
@@ -583,21 +601,21 @@ class DatasetStatisticsCacheClass:
             self.label_dset.save_to_disk(self.label_dset_fid)
             write_plotly(self.fig_labels, self.fig_labels_fid)
 
-    def load_or_prepare_npmi_terms(self, use_cache=False):
-        self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=use_cache)
+    def load_or_prepare_npmi_terms(self):
+        self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=self.use_cache)
         self.npmi_stats.load_or_prepare_npmi_terms()
 
-    def load_or_prepare_zipf(self, use_cache=False, save=True):
+    def load_or_prepare_zipf(self, save=True):
         # TODO: Current UI only uses the fig, meaning the self.z here is irrelevant
         # when only reading from cache. Either the UI should use it, or it should
         # be removed when reading in cache
-        if use_cache and exists(self.zipf_fig_fid) and exists(self.zipf_fid):
+        if self.use_cache and exists(self.zipf_fig_fid) and exists(self.zipf_fid):
            with open(self.zipf_fid, "r") as f:
                zipf_dict = json.load(f)
            self.z = Zipf()
            self.z.load(zipf_dict)
            self.zipf_fig = read_plotly(self.zipf_fig_fid)
-        elif use_cache and exists(self.zipf_fid):
+        elif self.use_cache and exists(self.zipf_fid):
             # TODO: Read zipf data so that the vocab is there.
             with open(self.zipf_fid, "r") as f:
                 zipf_dict = json.load(f)
@@ -643,17 +661,16 @@ class nPMIStatisticsCacheClass:
         self.available_terms = self.dstats.available_terms
         logs.info(self.available_terms)
 
-    def load_or_prepare_npmi_terms(self, use_cache=False):
+    def load_or_prepare_npmi_terms(self):
        """
        Figures out what identity terms the user can select, based on whether
        they occur more than self.min_vocab_count times
-        :param use_cache:
        :return: Identity terms occurring at least self.min_vocab_count times.
        """
        # TODO: Add the user's ability to select subgroups.
        # TODO: Make min_vocab_count here value selectable by the user.
        if (
-            use_cache
+            self.use_cache
            and exists(self.npmi_terms_fid)
            and json.load(open(self.npmi_terms_fid))["available terms"] != []
        ):
@@ -676,7 +693,7 @@ class nPMIStatisticsCacheClass:
         self.available_terms = available_terms
         return available_terms
 
-    def load_or_prepare_joint_npmi(self, subgroup_pair, use_cache=True):
+    def load_or_prepare_joint_npmi(self, subgroup_pair):
         """
         Run on-the fly, while the app is already open,
         as it depends on the subgroup terms that the user chooses
@@ -695,7 +712,7 @@ class nPMIStatisticsCacheClass:
         subgroup_files = define_subgroup_files(subgroup_pair, self.pmi_cache_path)
         # Defines the filenames for the cache files from the selected subgroups.
         # Get as much precomputed data as we can.
-        if use_cache and exists(joint_npmi_fid):
+        if self.use_cache and exists(joint_npmi_fid):
             # When everything is already computed for the selected subgroups.
             logs.info("Loading cached joint npmi")
             joint_npmi_df = self.load_joint_npmi_df(joint_npmi_fid)
@@ -850,8 +867,8 @@ class nPMIStatisticsCacheClass:
         csv_df.columns = [calc_str]
         return csv_df
 
-    def get_available_terms(self, use_cache=False):
-        return self.load_or_prepare_npmi_terms(use_cache=use_cache)
+    def get_available_terms(self):
+        return self.load_or_prepare_npmi_terms()
 
 def dummy(doc):
     return doc
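
The duplicate counting that moves into prepare_text_duplicates can be checked standalone. The snippet below mirrors the committed pivot_table idiom, with plain strings standing in for the OUR_TEXT_FIELD and CNT constants:

import pandas as pd

# Simplified reproduction of prepare_text_duplicates.
tokenized_df = pd.DataFrame({"text": ["a", "b", "a", "a", "b", "c"]})

# duplicated() marks every occurrence after the first, so the counts
# below are "extra copies", not total occurrences.
dup_df = tokenized_df[tokenized_df.duplicated(["text"])]
dup_counts_df = pd.DataFrame(
    dup_df.pivot_table(columns=["text"], aggfunc="size").sort_values(
        ascending=False
    ),
    columns=["count"],
)
print(dup_counts_df)
# "a" (3 occurrences) -> count 2; "b" (2 occurrences) -> count 1.
# This off-by-one is why the Streamlit widget comment mentions adding 1.
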
data_measurements/streamlit_utils.py CHANGED
@@ -136,12 +136,12 @@ def expander_general_stats(dstats, column_id):
 
 
 ### Show the label distribution from the datasets
-def expander_label_distribution(label_df, fig_labels, column_id):
+def expander_label_distribution(fig_labels, column_id):
     with st.expander(f"Label Distribution{column_id}", expanded=False):
         st.caption(
             "Use this widget to see how balanced the labels in your dataset are."
         )
-        if label_df is not None:
+        if fig_labels is not None:
             st.plotly_chart(fig_labels, use_container_width=True)
         else:
             st.markdown("No labels were found in the dataset")
@@ -285,7 +285,7 @@ def expander_text_duplicates(dstats, column_id):
         "### Here is the list of all the duplicated items and their counts in your dataset:"
     )
     # Eh...adding 1 because otherwise it looks too weird for duplicate counts when the value is just 1.
-    if len(dstats.dup_counts_df) == 0:
+    if dstats.dup_counts_df is None:
         st.write("There are no duplicates in this dataset! 🥳")
     else:
         gb = GridOptionsBuilder.from_dataframe(dstats.dup_counts_df)
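
The final hunk's guard change follows from the lazy loading above: dup_counts_df can now legitimately still be None when the widget renders. Below is a sketch of the resulting pattern; the combined None-or-empty check is an extra safeguard beyond what this commit does (the committed code tests only for None):

import pandas as pd

def render_duplicates(dup_counts_df):
    # With lazy loaders, the DataFrame may never have been materialized,
    # so a None check replaces the old len(...) == 0 check. Guarding the
    # computed-but-empty case as well is an assumption, not in the commit.
    if dup_counts_df is None or len(dup_counts_df) == 0:
        print("There are no duplicates in this dataset!")
    else:
        print(dup_counts_df.to_string())

render_duplicates(None)  # nothing computed yet
render_duplicates(pd.DataFrame({"text": ["a"], "count": [2]}))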