meg-huggingface committed on
Commit
a2ae370
1 Parent(s): 335424f

More modularizing; npmi and labels

Browse files
app.py CHANGED
@@ -118,9 +118,8 @@ def load_or_prepare(ds_args, show_embeddings, use_cache=False):
118
  if show_embeddings:
119
  logs.warning("Loading Embeddings")
120
  dstats.load_or_prepare_embeddings()
121
- # TODO: This has now been moved to calculation when the npmi widget is loaded.
122
- logs.warning("Loading Terms for nPMI")
123
- dstats.load_or_prepare_npmi_terms()
124
  logs.warning("Loading Zipf")
125
  dstats.load_or_prepare_zipf()
126
  return dstats
@@ -156,6 +155,8 @@ def load_or_prepare_widgets(ds_args, show_embeddings, use_cache=False):
156
  # Embeddings widget
157
  dstats.load_or_prepare_embeddings()
158
  dstats.load_or_prepare_text_duplicates()
 
 
159
 
160
  def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=True):
161
  """
@@ -179,17 +180,9 @@ def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=T
179
  st_utils.expander_label_distribution(dstats.fig_labels, column_id)
180
  st_utils.expander_text_lengths(dstats, column_id)
181
  st_utils.expander_text_duplicates(dstats, column_id)
182
-
183
- # We do the loading of these after the others in order to have some time
184
- # to compute while the user works with the details above.
185
  # Uses an interaction; handled a bit differently than other widgets.
186
  logs.info("showing npmi widget")
187
- npmi_stats = dataset_statistics.nPMIStatisticsCacheClass(
188
- dstats, use_cache=use_cache
189
- )
190
- available_terms = npmi_stats.get_available_terms()
191
- st_utils.npmi_widget(
192
- column_id, available_terms, npmi_stats, _MIN_VOCAB_COUNT)
193
  logs.info("showing zipf")
194
  st_utils.expander_zipf(dstats.z, dstats.zipf_fig, column_id)
195
  if show_embeddings:
 
118
  if show_embeddings:
119
  logs.warning("Loading Embeddings")
120
  dstats.load_or_prepare_embeddings()
121
+ logs.warning("Loading nPMI")
122
+ dstats.load_or_prepare_npmi()
 
123
  logs.warning("Loading Zipf")
124
  dstats.load_or_prepare_zipf()
125
  return dstats
 
155
  # Embeddings widget
156
  dstats.load_or_prepare_embeddings()
157
  dstats.load_or_prepare_text_duplicates()
158
+ dstats.load_or_prepare_npmi()
159
+ dstats.load_or_prepare_zipf()
160
 
161
  def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=True):
162
  """
 
180
  st_utils.expander_label_distribution(dstats.fig_labels, column_id)
181
  st_utils.expander_text_lengths(dstats, column_id)
182
  st_utils.expander_text_duplicates(dstats, column_id)
 
 
 
183
  # Uses an interaction; handled a bit differently than other widgets.
184
  logs.info("showing npmi widget")
185
+ st_utils.npmi_widget(dstats.npmi_stats, _MIN_VOCAB_COUNT, column_id)
 
 
 
 
 
186
  logs.info("showing zipf")
187
  st_utils.expander_zipf(dstats.z, dstats.zipf_fig, column_id)
188
  if show_embeddings:
data_measurements/dataset_statistics.py CHANGED
@@ -231,10 +231,6 @@ class DatasetStatisticsCacheClass:
231
  # nPMI
232
  # Holds a nPMIStatisticsCacheClass object
233
  self.npmi_stats = None
234
- # TODO: Users ideally can type in whatever words they want.
235
- self.termlist = _IDENTITY_TERMS
236
- # termlist terms that are available more than _MIN_VOCAB_COUNT times
237
- self.available_terms = _IDENTITY_TERMS
238
  # TODO: Have lowercase be an option for a user to set.
239
  self.to_lowercase = True
240
  # The minimum amount of times a word should occur to be included in
@@ -627,24 +623,27 @@ class DatasetStatisticsCacheClass:
627
  if save:
628
  write_plotly(self.fig_labels, self.fig_labels_fid)
629
  else:
630
- self.get_base_dataset()
631
- self.label_dset = self.dset.map(
632
- lambda examples: extract_field(
633
- examples, self.label_field, OUR_LABEL_FIELD
634
- ),
635
- batched=True,
636
- remove_columns=list(self.dset.features),
637
- )
638
- self.label_df = self.label_dset.to_pandas()
639
- self.fig_labels = make_fig_labels(
640
- self.label_df, self.label_names, OUR_LABEL_FIELD
641
- )
642
  if save:
643
  # save extracted label instances
644
  self.label_dset.save_to_disk(self.label_dset_fid)
645
  write_plotly(self.fig_labels, self.fig_labels_fid)
646
 
647
- def load_or_prepare_npmi_terms(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
648
  self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=self.use_cache)
649
  self.npmi_stats.load_or_prepare_npmi_terms()
650
 
@@ -693,7 +692,10 @@ class nPMIStatisticsCacheClass:
693
  # We need to preprocess everything.
694
  mkdir(self.pmi_cache_path)
695
  self.joint_npmi_df_dict = {}
696
- self.termlist = self.dstats.termlist
 
 
 
697
  logs.info(self.termlist)
698
  self.use_cache = use_cache
699
  # TODO: Let users specify
@@ -701,8 +703,6 @@ class nPMIStatisticsCacheClass:
701
  self.min_vocab_count = self.dstats.min_vocab_count
702
  self.subgroup_files = {}
703
  self.npmi_terms_fid = pjoin(self.dstats.cache_path, "npmi_terms.json")
704
- self.available_terms = self.dstats.available_terms
705
- logs.info(self.available_terms)
706
 
707
  def load_or_prepare_npmi_terms(self):
708
  """
 
231
  # nPMI
232
  # Holds a nPMIStatisticsCacheClass object
233
  self.npmi_stats = None
 
 
 
 
234
  # TODO: Have lowercase be an option for a user to set.
235
  self.to_lowercase = True
236
  # The minimum amount of times a word should occur to be included in
 
623
  if save:
624
  write_plotly(self.fig_labels, self.fig_labels_fid)
625
  else:
626
+ self.prepare_labels()
 
 
 
 
 
 
 
 
 
 
 
627
  if save:
628
  # save extracted label instances
629
  self.label_dset.save_to_disk(self.label_dset_fid)
630
  write_plotly(self.fig_labels, self.fig_labels_fid)
631
 
632
+ def prepare_labels(self):
633
+ self.get_base_dataset()
634
+ self.label_dset = self.dset.map(
635
+ lambda examples: extract_field(
636
+ examples, self.label_field, OUR_LABEL_FIELD
637
+ ),
638
+ batched=True,
639
+ remove_columns=list(self.dset.features),
640
+ )
641
+ self.label_df = self.label_dset.to_pandas()
642
+ self.fig_labels = make_fig_labels(
643
+ self.label_df, self.label_names, OUR_LABEL_FIELD
644
+ )
645
+
646
+ def load_or_prepare_npmi(self):
647
  self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=self.use_cache)
648
  self.npmi_stats.load_or_prepare_npmi_terms()
649
 
 
692
  # We need to preprocess everything.
693
  mkdir(self.pmi_cache_path)
694
  self.joint_npmi_df_dict = {}
695
+ # TODO: Users ideally can type in whatever words they want.
696
+ self.termlist = _IDENTITY_TERMS
697
+ # termlist terms that are available more than _MIN_VOCAB_COUNT times
698
+ self.available_terms = _IDENTITY_TERMS
699
  logs.info(self.termlist)
700
  self.use_cache = use_cache
701
  # TODO: Let users specify
 
703
  self.min_vocab_count = self.dstats.min_vocab_count
704
  self.subgroup_files = {}
705
  self.npmi_terms_fid = pjoin(self.dstats.cache_path, "npmi_terms.json")
 
 
706
 
707
  def load_or_prepare_npmi_terms(self):
708
  """
data_measurements/streamlit_utils.py CHANGED
@@ -273,7 +273,6 @@ def expander_text_duplicates(dstats, column_id):
273
  st.write(
274
  "### Here is the list of all the duplicated items and their counts in your dataset:"
275
  )
276
- # Eh...adding 1 because otherwise it looks too weird for duplicate counts when the value is just 1.
277
  if dstats.dup_counts_df is None:
278
  st.write("There are no duplicates in this dataset! 🥳")
279
  else:
@@ -393,7 +392,7 @@ with an ideal α value of 1."""
393
 
394
 
395
  ### Finally finally finally, show nPMI stuff.
396
- def npmi_widget(column_id, available_terms, npmi_stats, min_vocab):
397
  """
398
  Part of the main app, but uses a user interaction so pulled out as its own f'n.
399
  :param use_cache:
@@ -403,16 +402,16 @@ def npmi_widget(column_id, available_terms, npmi_stats, min_vocab):
403
  :return:
404
  """
405
  with st.expander(f"Word Association{column_id}: nPMI", expanded=False):
406
- if len(available_terms) > 0:
407
  expander_npmi_description(min_vocab)
408
  st.markdown("-----")
409
  term1 = st.selectbox(
410
  f"What is the first term you want to select?{column_id}",
411
- available_terms,
412
  )
413
  term2 = st.selectbox(
414
  f"What is the second term you want to select?{column_id}",
415
- reversed(available_terms),
416
  )
417
  # We calculate/grab nPMI data based on a canonical (alphabetic)
418
  # subgroup ordering.
 
273
  st.write(
274
  "### Here is the list of all the duplicated items and their counts in your dataset:"
275
  )
 
276
  if dstats.dup_counts_df is None:
277
  st.write("There are no duplicates in this dataset! 🥳")
278
  else:
 
392
 
393
 
394
  ### Finally finally finally, show nPMI stuff.
395
+ def npmi_widget(npmi_stats, min_vocab, column_id):
396
  """
397
  Part of the main app, but uses a user interaction so pulled out as its own f'n.
398
  :param use_cache:
 
402
  :return:
403
  """
404
  with st.expander(f"Word Association{column_id}: nPMI", expanded=False):
405
+ if len(npmi_stats.available_terms) > 0:
406
  expander_npmi_description(min_vocab)
407
  st.markdown("-----")
408
  term1 = st.selectbox(
409
  f"What is the first term you want to select?{column_id}",
410
+ npmi_stats.available_terms,
411
  )
412
  term2 = st.selectbox(
413
  f"What is the second term you want to select?{column_id}",
414
+ reversed(npmi_stats.available_terms),
415
  )
416
  # We calculate/grab nPMI data based on a canonical (alphabetic)
417
  # subgroup ordering.